diff --git a/3rdparty/backend-r22.12/.clang-format b/3rdparty/backend-r22.12/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..98c649734c29e0b1d134dae65be9bc08a14b4ba5 --- /dev/null +++ b/3rdparty/backend-r22.12/.clang-format @@ -0,0 +1,37 @@ +--- +BasedOnStyle: Google + +IndentWidth: 2 +ContinuationIndentWidth: 4 +UseTab: Never +MaxEmptyLinesToKeep: 2 + +SortIncludes: true +CompactNamespaces: true +ReflowComments: true + +DerivePointerAlignment: false +PointerAlignment: Left + +AllowShortIfStatementsOnASingleLine: false +AllowShortBlocksOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline + +AlwaysBreakAfterReturnType: TopLevelDefinitions +AlignAfterOpenBracket: AlwaysBreak +BreakBeforeBraces: Custom +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: true + AfterNamespace: false + AfterStruct: false + AfterUnion: false + BeforeCatch: true + +BinPackArguments: true +BinPackParameters: true +ConstructorInitializerAllOnOneLineOrOnePerLine: false + +IndentCaseLabels: true \ No newline at end of file diff --git a/3rdparty/backend-r22.12/.gitignore b/3rdparty/backend-r22.12/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0e9f099a2eef4742716637e3cce3a45f7053b021 --- /dev/null +++ b/3rdparty/backend-r22.12/.gitignore @@ -0,0 +1,3 @@ +/build +/.vscode +*.so diff --git a/3rdparty/backend-r22.12/LICENSE b/3rdparty/backend-r22.12/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..1b34054e482218d517a8b190ee112ee99740f976 --- /dev/null +++ b/3rdparty/backend-r22.12/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of NVIDIA CORPORATION nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/3rdparty/backend-r22.12/README.md b/3rdparty/backend-r22.12/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0ca36f1edf54a41ca26a7db3443323e500ccdf7f --- /dev/null +++ b/3rdparty/backend-r22.12/README.md @@ -0,0 +1,540 @@ + + +[![License](https://img.shields.io/badge/License-BSD3-lightgrey.svg)](https://opensource.org/licenses/BSD-3-Clause) + +# Triton Inference Server Backend + +A Triton *backend* is the implementation that executes a model. A +backend can be a wrapper around a deep-learning framework, like +PyTorch, TensorFlow, TensorRT or ONNX Runtime. Or a backend can be +custom C/C++ logic performing any operation (for example, image +pre-processing). + +This repo contains documentation on Triton backends and also source, +scripts and utilities for creating Triton backends. You do not need to +use anything provided in this repo to create a Triton backend but you +will likely find its contents useful. + +## Frequently Asked Questions + +Full documentation is included below but these shortcuts can help you +get started in the right direction. + +### Where can I ask general questions about Triton and Triton backends? + +Be sure to read all the information below as well as the [general +Triton +documentation](https://github.com/triton-inference-server/server#triton-inference-server) +available in the main +[server](https://github.com/triton-inference-server/server) repo. If +you don't find your answer there you can ask questions on the main +Triton [issues +page](https://github.com/triton-inference-server/server/issues). + +### Where can I find all the backends that are available for Triton? + +Anyone can develop a Triton backend, so it isn't possible for us to +know about all available backends. But the Triton project does provide +a set of supported backends that are tested and updated with each +Triton release. + +**TensorRT**: The TensorRT backend is used to execute TensorRT +models. The +[server](https://github.com/triton-inference-server/tensorrt_backend) +repo contains the source for the backend. + +**ONNX Runtime**: The ONNX Runtime backend is used to execute ONNX +models. The +[onnxruntime_backend](https://github.com/triton-inference-server/onnxruntime_backend) +repo contains the documentation and source for the backend. + +**TensorFlow**: The TensorFlow backend is used to execute TensorFlow +models in both GraphDef and SavedModel formats. The same backend is +used to execute both TensorFlow 1 and TensorFlow 2 models. The +[tensorflow_backend](https://github.com/triton-inference-server/tensorflow_backend) +repo contains the documentation and source for the backend. + +**PyTorch**: The PyTorch backend is used to execute TorchScript +models. The +[pytorch_backend](https://github.com/triton-inference-server/pytorch_backend) +repo contains the documentation and source for the backend. + +**OpenVINO**: The OpenVINO backend is used to execute +[OpenVINO](https://docs.openvinotoolkit.org/latest/index.html) +models. The +[openvino_backend](https://github.com/triton-inference-server/openvino_backend) +repo contains the documentation and source for the backend. + +**Python**: The Python backend allows you to write your model logic in +Python. For example, you can use this backend to execute pre/post +processing code written in Python, or to execute a PyTorch Python +script directly (instead of first converting it to TorchScript and +then using the PyTorch backend). 
The +[python_backend](https://github.com/triton-inference-server/python_backend) +repo contains the documentation and source for the backend. + +**DALI**: [DALI](https://github.com/NVIDIA/DALI) is a collection of +highly optimized building blocks and an execution engine that +accelerates the pre-processing of the input data for deep learning +applications. The DALI backend allows you to execute your DALI +pipeline within Triton. The +[dali_backend](https://github.com/triton-inference-server/dali_backend) +repo contains the documentation and source for the backend. + +**FIL**: The FIL ([Forest Inference +Library](https://github.com/rapidsai/cuml/tree/branch-21.10/python/cuml/fil)) +backend is used to execute a variety of tree-based ML models, including +XGBoost models, LightGBM models, Scikit-Learn random forest models, and cuML +random forest models. The +[fil_backend](https://github.com/triton-inference-server/fil_backend) repo +contains the documentation and source for the backend. + +**Important Note!** Not all the above backends are supported on every platform +supported by Triton. Look at the +[Backend-Platform Support Matrix](docs/backend_platform_support_matrix.md) +to learn about the same. + +### How can I develop my own Triton backend? + +First you probably want to ask on the main Triton [issues +page](https://github.com/triton-inference-server/server/issues) to +make sure you are not duplicating a backend that already exists. Then +follow the [tutorial](examples/README.md) to learn how to create your +first simple Triton backend and incrementally improve it to add more +features. You should also read the complete documentation on [Triton +backends](#backends). + +### Can I add (or remove) a backend to an existing Triton installation? + +Yes. See [Backend Shared Library](#backend-shared-library) for general +information about how the shared library implementing a backend is +managed by Triton, and [Triton with Unsupported and Custom +Backends](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/compose.md#triton-with-unsupported-and-custom-backends) +for documentation on how to add your backend to the released Triton +Docker image. For a standard install the globally available backends +are in /opt/tritonserver/backends. + +### What about backends developed using the "legacy custom backend" API. + +The legacy custom API is removed from Triton. If you have custom +backends that you developed using this older API you must port them to +the new [Triton Backend API](#triton-backend-api). + +## Backends + +A Triton *backend* is the implementation that executes a model. A +backend can be a wrapper around a deep-learning framework, like +PyTorch, TensorFlow, TensorRT, ONNX Runtime or OpenVINO. A backend can +also implement any functionality you want as long as it adheres to the +[backend API](#triton-backend-api). Triton uses this API to send +requests to the backend for execution and the backend uses the API to +communicate with Triton. + +Every model must be associated with a backend. A model's backend is +specified in the model's configuration using the 'backend' setting. +For using TensorRT backend, the value of this setting should be *tensorrt*. +Similarly, for using PyTorch, ONNX and TensorFlow Backends, the `backend` +field should be set to *pytorch*, *onnxruntime* or *tensorflow* respectively. +For all other backends, 'backend' must be set to the name of the backend. 
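+
+For illustration only, a model that selects a hypothetical custom backend
+named "mybackend" might use a configuration along the following lines (the
+model name, tensor names, shapes and datatypes here are placeholders, not
+part of this repo):
+
+```
+name: "my_model"
+backend: "mybackend"
+max_batch_size: 8
+input [
+  {
+    name: "IN0"
+    data_type: TYPE_INT32
+    dims: [ 4 ]
+  }
+]
+output [
+  {
+    name: "OUT0"
+    data_type: TYPE_INT32
+    dims: [ 4 ]
+  }
+]
+```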
+ +### Backend Shared Library + +Each backend must be implemented as a shared library and the name of +the shared library must be *libtriton_\.so*. For +example, if the name of the backend is "mybackend", a model indicates +that it uses the backend by setting the model configuration 'backend' +setting to "mybackend", and Triton looks for *libtriton_mybackend.so* +as the shared library that implements the backend. The +[tutorial](examples/README.md) shows examples of how to build your +backend logic into the appropriate shared library. + +For a model, *M* that specifies backend *B*, Triton searches for the +backend shared library in the following places, in this order: + +* \/M/\/libtriton_B.so + +* \/M/libtriton_B.so + +* \/B/libtriton_B.so + +Where \ is by default +/opt/tritonserver/backends. The --backend-directory flag can be used +to override the default. + +Typically you will install your backend into the global backend +directory. For example, if using Triton Docker images you can follow +the instructions in [Triton with Unsupported and Custom +Backends](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/compose.md#triton-with-unsupported-and-custom-backends). Continuing +the example of a backend names "mybackend", you would install into the +Triton image as: + +``` +/opt/ + tritonserver/ + backends/ + mybackend/ + libtriton_mybackend.so + ... # other files needed by mybackend +``` + +### Triton Backend API + +A Triton backend must implement the C interface defined in +[tritonbackend.h](https://github.com/triton-inference-server/core/tree/main/include/triton/core/tritonbackend.h). The +following abstractions are used by the API. + +#### TRITONBACKEND_Backend + +A TRITONBACKEND_Backend object represents the backend itself. The +same backend object is shared across all models that use the +backend. The associated API, like TRITONBACKEND_BackendName, is used +to get information about the backend and to associate a user-defined +state with the backend. + +A backend can optionally implement TRITONBACKEND_Initialize and +TRITONBACKEND_Finalize to get notification of when the backend object +is created and destroyed (for more information see [backend +lifecycles](#backend-lifecycles)). + +#### TRITONBACKEND_Model + +A TRITONBACKEND_Model object represents a model. Each model loaded by +Triton is associated with a TRITONBACKEND_Model. Each model can use +the TRITONBACKEND_ModelBackend API to get the backend object +representing the backend that is used by the model. + +The same model object is shared across all instances of that +model. The associated API, like TRITONBACKEND_ModelName, is used to +get information about the model and to associate a user-defined state +with the model. + +Most backends will implement TRITONBACKEND_ModelInitialize and +TRITONBACKEND_ModelFinalize to initialize the backend for a given +model and to manage the user-defined state associated with the model +(for more information see [backend lifecycles](#backend-lifecycles)). + +The backend must take into account threading concerns when +implementing TRITONBACKEND_ModelInitialize and +TRITONBACKEND_ModelFinalize. Triton will not perform multiple +simultaneous calls to these functions for a given model; however, if a +backend is used by multiple models Triton may simultaneously call the +functions with a different thread for each model. As a result, the +backend must be able to handle multiple simultaneous calls to the +functions. 
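+
+One way to meet this requirement is to keep all mutable data behind the
+per-model state that Triton lets a backend attach to each
+TRITONBACKEND_Model. The sketch below is illustrative only: the bare
+`ModelState` struct is hypothetical, and the example backends in this repo
+build richer state classes on top of the backend utilities instead.
+
+```
+#include <string>
+
+#include "triton/backend/backend_common.h"  // RETURN_IF_ERROR
+#include "triton/core/tritonbackend.h"
+
+// Hypothetical per-model state; every TRITONBACKEND_Model gets its own
+// copy, so simultaneous initialization of different models shares nothing.
+struct ModelState {
+  std::string name;
+};
+
+extern "C" TRITONSERVER_Error*
+TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model)
+{
+  const char* name;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelName(model, &name));
+
+  // Attach state that is private to this model.
+  ModelState* state = new ModelState{name};
+  RETURN_IF_ERROR(
+      TRITONBACKEND_ModelSetState(model, reinterpret_cast<void*>(state)));
+  return nullptr;  // success
+}
+
+extern "C" TRITONSERVER_Error*
+TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model)
+{
+  void* vstate;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate));
+  delete reinterpret_cast<ModelState*>(vstate);
+  return nullptr;  // success
+}
+```
+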
Best practice for backend implementations is to use only +function-local and model-specific user-defined state in these +functions, as is shown in the [tutorial](examples/README.md). + +#### TRITONBACKEND_ModelInstance + +A TRITONBACKEND_ModelInstance object represents a model +*instance*. Triton creates one or more instances of the model based on +the *instance_group* settings specified in the model +configuration. Each of these instances is associated with a +TRITONBACKEND_ModelInstance object. + +The only function that the backend must implement is +TRITONBACKEND_ModelInstanceExecute. The +TRITONBACKEND_ModelInstanceExecute function is called by Triton to +perform inference/computation on a batch of inference requests. Most +backends will also implement TRITONBACKEND_ModelInstanceInitialize +and TRITONBACKEND_ModelInstanceFinalize to initialize the backend for +a given model instance and to manage the user-defined state associated +with the model (for more information see [backend +lifecycles](#backend-lifecycles)). + +The backend must take into account threading concerns when +implementing TRITONBACKEND_ModelInstanceInitialize, +TRITONBACKEND_ModelInstanceFinalize and +TRITONBACKEND_ModelInstanceExecute. Triton will not perform multiple +simultaneous calls to these functions for a given model instance; +however, if a backend is used by a model with multiple instances or by +multiple models Triton may simultaneously call the functions with a +different thread for each model instance. As a result, the backend +must be able to handle multiple simultaneous calls to the +functions. Best practice for backend implementations is to use only +function-local and model-specific user-defined state in these +functions, as is shown in the [tutorial](examples/README.md). + +#### TRITONBACKEND_Request + +A TRITONBACKEND_Request object represents an inference request made +to the model. The backend takes ownership of the request object(s) in +TRITONBACKEND_ModelInstanceExecute and must release each request by +calling TRITONBACKEND_RequestRelease. However, the ownership of request +object is returned back to Triton in case TRITONBACKEND_ModelInstanceExecute +returns an error. See [Inference Requests and Responses](#inference-requests-and-responses) +for more information about request lifecycle. + +The Triton Backend API allows the backend to get information about the +request as well as the input and request output tensors of the +request. Each request input is represented by a TRITONBACKEND_Input +object. + +#### TRITONBACKEND_Response + +A TRITONBACKEND_Response object represents a response sent by the +backend for a specific request. The backend uses the response API to +set the name, shape, datatype and tensor values for each output tensor +included in the response. The response can indicate either a failed or +a successful request. See [Inference Requests and +Responses](#inference-requests-and-responses) for more information +about request-response lifecycle. + +### Backend Lifecycles + +A backend must carefully manage the lifecycle of the backend itself, +the models and model instances that use the backend and the inference +requests that execute on the model instances using the backend. + +#### Backend and Model + +Backend, model and model instance initialization is triggered when +Triton loads a model. 
+ +* If the model requires a backend that is not already in use by an + already loaded model, then: + + * Triton [loads the shared library](#backend-shared-library) that + implements the backend required by the model. + + * Triton creates the TRITONBACKEND_Backend object that represents + the backend. + + * Triton calls TRITONBACKEND_Initialize if it is implemented in the + backend shared library. TRITONBACKEND_Initialize should not return + until the backend is completely initialized. If + TRITONBACKEND_Initialize returns an error, Triton will report that + the model failed to load. + +* Triton creates the TRITONBACKEND_Model object that represents the + model. Triton calls TRITONBACKEND_ModelInitialize if it is + implemented in the backend shared library. + TRITONBACKEND_ModelInitialize should not return until the backend + is completely initialized for the model. If + TRITONBACKEND_ModelInitialize returns an error, Triton will show + that the model failed to load. + +* For each model instance specified for the model in the model + configuration: + + * Triton creates the TRITONBACKEND_ModelInstance object that + represents the model instance. + + * Triton calls TRITONBACKEND_ModelInstanceInitialize if it is + implemented in the backend shared library. + TRITONBACKEND_ModelInstanceInitialize should not return until the + backend is completely initialized for the instance. If + TRITONBACKEND_ModelInstanceInitialize returns an error, Triton + will show that the model failed to load. + +Backend, model and model instance finalization is triggered when +Triton unloads a model. + +* For each model instance: + + * Triton calls TRITONBACKEND_ModelInstanceFinalize if it is + implemented in the backend shared library. + TRITONBACKEND_ModelInstanceFinalize should not return until the + backend is completely finalized, including stopping any threads + create for the model instance and freeing any user-defined state + created for the model instance. + + * Triton destroys the TRITONBACKEND_ModelInstance object that + represents the model instance. + +* Triton calls TRITONBACKEND_ModelFinalize if it is implemented in the + backend shared library. TRITONBACKEND_ModelFinalize should not + return until the backend is completely finalized, including stopping + any threads create for the model and freeing any user-defined state + created for the model. + +* Triton destroys the TRITONBACKEND_Model object that represents the + model. + +* Even if no other loaded model requires the backend, Triton does not + finalize and unload the backend until the tritonserver process is + exiting. When the tritonserver process exits: + + * Triton calls TRITONBACKEND_Finalize if it is implemented in the + backend shared library. TRITONBACKEND_ModelFinalize should not + return until the backend is completely finalized, including + stopping any threads create for the backend and freeing any + user-defined state created for the backend. + + * Triton destroys the TRITONBACKEND_Backend object that represents + the backend. + +#### Inference Requests and Responses + +Triton calls TRITONBACKEND_ModelInstanceExecute to execute inference +requests on a model instance. Each call to +TRITONBACKEND_ModelInstanceExecute communicates a batch of requests +to execute and the instance of the model that should be used to +execute those requests. The backend should not allow the caller +thread to return from TRITONBACKEND_ModelInstanceExecute until that +instance is ready to handle another set of requests. 
Typically this +means that the TRITONBACKEND_ModelInstanceExecute function will +create responses and release the requests before returning. However, +in case TRITONBACKEND_ModelInstanceExecute returns an error, the ownership +of requests is transferred back to Triton which will then be responsible +for releasing them. Therefore, in the case where TRITONBACKEND_ModelInstanceExecute +returns an error, the backend must not retain references to the requests +or access them in any way. For more detailed description of request/response +lifetimes, study the documentation of TRITONBACKEND_ModelInstanceExecute in +[tritonbackend.h](https://github.com/triton-inference-server/core/blob/main/include/triton/core/tritonbackend.h). + +##### Single Response + +Most backends will create a single response for each request. For that +kind of backend, executing a single inference request requires the +following steps: + +* Create a response for the request using TRITONBACKEND_ResponseNew. + +* For each request input tensor use TRITONBACKEND_InputProperties to + get shape and datatype of the input as well as the buffer(s) + containing the tensor contents. + +* For each output tensor which the request expects to be returned, use + TRITONBACKEND_ResponseOutput to create the output tensor of the + required datatype and shape. Use TRITONBACKEND_OutputBuffer to get a + pointer to the buffer where the tensor's contents should be written. + +* Use the inputs to perform the inference computation that produces + the requested output tensor contents into the appropriate output + buffers. + +* Optionally set parameters in the response. + +* Send the response using TRITONBACKEND_ResponseSend. + +* Release the request using TRITONBACKEND_RequestRelease. + +For a batch of requests the backend should attempt to combine the +execution of the individual requests as much as possible to increase +performance. + +##### Decoupled Responses + +It is also possible for a backend to send multiple responses for a +request or not send any responses for a request. A backend may also +send responses out-of-order relative to the order that the request +batches are executed. Such backends are called *decoupled* backends. +The decoupled backends use one `ResponseFactory` object per request to keep +creating and sending any number of responses for the request. For this +kind of backend, executing a single inference request typically requires +the following steps: + +* For each request input tensor use TRITONBACKEND_InputProperties to + get shape and datatype of the input as well as the buffer(s) + containing the tensor contents. + +* Create a `ResponseFactory` object for the request using + TRITONBACKEND_ResponseFactoryNew. + + 1. Create a response from the `ResponseFactory` object using + TRITONBACKEND_ResponseNewFromFactory. As long as you have + `ResponseFactory` object you can continue creating responses. + + 2. For each output tensor which the request expects to be returned, use + TRITONBACKEND_ResponseOutput to create the output tensor of the + required datatype and shape. Use TRITONBACKEND_OutputBuffer to get a + pointer to the buffer where the tensor's contents should be written. + + 3. Use the inputs to perform the inference computation that produces + the requested output tensor contents into the appropriate output + buffers. + + 4. Optionally set parameters in the response. + + 5. Send the response using TRITONBACKEND_ResponseSend. 
If this is the + last request then use TRITONSERVER_ResponseCompleteFlag with + TRITONBACKEND_ResponseSend. Otherwise continue with Step 1 for + sending next request + +* Release the request using TRITONBACKEND_RequestRelease. + +###### Special Cases + +The decoupled API is powerful and supports various special cases: + +* If the backend should not send any response for the request, + TRITONBACKEND_ResponseFactorySendFlags can be used to send + TRITONSERVER_RESPONSE_COMPLETE_FINAL using the `ResponseFactory`. + +* The model can also send responses out-of-order in which it received + requests. + +* The backend can copy out the contents of the input buffer(s) if + request is to be released before the contents are completely + consumed to generate responses. After copy, the request can be + released anytime before exiting TRITONBACKEND_ModelInstanceExecute. + The copies and `ResponseFactory` object can be passed to a separate + thread in backend. This means main caller thread can exit from + TRITONBACKEND_ModelInstanceExecute and the backend can still continue + generating responses as long as it holds `ResponseFactory` object. + + +The [repeat example](examples/README.md) demonstrates full power of +what can be acheived from decoupled API. + + +Study documentation of these TRTIONBACKEND_* functions in +[tritonbackend.h](https://github.com/triton-inference-server/core/blob/main/include/triton/core/tritonbackend.h) +for more details on these APIs. Read +[Decoupled Backends and Models](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/decoupled_models.md) +for more details on how to host a decoupled model. + +## Build the Backend Utilities + +The source in this repo builds into a single "backend utilities" +library that is useful when building backends. You don't need to use +these utilities but they will be helpful for most backends. + +Typically you don't need to build this repo directly but instead you +can include it in the build of your backend as is shown in the +CMakeLists.txt files of the [tutorial examples](examples/README.md). + +To build and install in a local directory use the following commands. + +``` +$ mkdir build +$ cd build +$ cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install .. +$ make install +``` + +The following required Triton repositories will be pulled and used in +the build. By default the "main" branch/tag will be used for each repo +but the listed CMake argument can be used to override. + +* triton-inference-server/common: -DTRITON_COMMON_REPO_TAG=[tag] +* triton-inference-server/core: -DTRITON_CORE_REPO_TAG=[tag] + +See the [CMakeLists.txt](CMakeLists.txt) file for other build options. diff --git a/3rdparty/backend-r22.12/cmake/TritonBackendConfig.cmake.in b/3rdparty/backend-r22.12/cmake/TritonBackendConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..a0fdea4fe593ce4ca5310d9ea05d8ea0cf2f09fa --- /dev/null +++ b/3rdparty/backend-r22.12/cmake/TritonBackendConfig.cmake.in @@ -0,0 +1,39 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include(CMakeFindDependencyMacro) + +get_filename_component( + TRITONBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH +) + +list(APPEND CMAKE_MODULE_PATH ${TRITONBACKEND_CMAKE_DIR}) + +if(NOT TARGET TritonBackend::triton-backend-utils) + include("${TRITONBACKEND_CMAKE_DIR}/TritonBackendTargets.cmake") +endif() + +set(TRITONBACKEND_LIBRARIES TritonBackend::triton-backend-utils) diff --git a/3rdparty/backend-r22.12/docs/backend_platform_support_matrix.md b/3rdparty/backend-r22.12/docs/backend_platform_support_matrix.md new file mode 100644 index 0000000000000000000000000000000000000000..58341c172d0d054d59c8c6c8d9e8f4c6cacc3c54 --- /dev/null +++ b/3rdparty/backend-r22.12/docs/backend_platform_support_matrix.md @@ -0,0 +1,99 @@ + + +# Backend-Platform Support Matrix + +Even though Triton supports inference across various platforms such as +cloud, data center, edge and embedded devices on NVIDIA GPUs, x86 and +ARM CPU, or AWS Inferentia, it does so by relying on the backends. +Note that not all Triton backends support every platform. The purpose +of this document is to go over what all compute platforms are supported +by each of these Triton backends. +GPU in this document refers to Nvidia GPU. See +[GPU, Driver, and CUDA Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html) +to learn more about supported GPUs. + +## Ubuntu 20.04 + +The table below describes target device(s) supported for inference by +each backend on different platforms. + +| Backend | x86 | ARM-SBSA | +| ------------ | --------- | ------------- | +| TensorRT | :heavy_check_mark: GPU
<br/>:x: CPU | :heavy_check_mark: GPU<br/>:x: CPU |
+| ONNX Runtime | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU |
+| TensorFlow | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU |
+| PyTorch | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU |
+| OpenVINO | :x: GPU<br/>:heavy_check_mark: CPU | :x: GPU<br/>:x: CPU |
+| Python[^1] | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU |
+| DALI | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU | :heavy_check_mark: GPU[^2]<br/>:heavy_check_mark: CPU[^2] |
+| FIL | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU | Unsupported |
+
+
+
+## Windows 10
+
+Only TensorRT and ONNX Runtime backends are supported on Windows.
+
+| Backend | x86 | ARM-SBSA |
+| ------------ | --------- | ------------- |
+| TensorRT | :heavy_check_mark: GPU<br/>:x: CPU | :heavy_check_mark: GPU<br/>:x: CPU |
+| ONNX Runtime | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU |
+
+
+## Jetson JetPack
+
+The following backends are currently supported on Jetson JetPack:
+
+| Backend | Jetson |
+| ------------ | --------- |
+| TensorRT | :heavy_check_mark: GPU<br/>:x: CPU |
+| ONNX Runtime | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU |
+| TensorFlow | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU |
+| PyTorch | :heavy_check_mark: GPU<br/>:heavy_check_mark: CPU |
+| Python[^1] | :x: GPU<br/>
:heavy_check_mark: CPU | + + +Look at the [Triton Inference Server Support for Jetson and JetPack](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/jetson.md). + + +## AWS Inferentia + +Currently, inference on AWS Inferentia is only supported via +[python backend](https://github.com/triton-inference-server/python_backend#running-with-inferentia) +where the deployed python script invokes AWS Neuron SDK. + + +[^1]: The supported devices for Python Backend are mentioned with +respect to Triton. The python script running in Python Backend can +be used to execute inference on any hardware if there are available +python APIs to do so. AWS inferentia is one such example. Triton +core is largely unaware of the fact that inference will run on +Inferentia. + +[^2]: In case of ARM-SBSA, some operations are not fully supported. diff --git a/3rdparty/backend-r22.12/examples/README.md b/3rdparty/backend-r22.12/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e0ddc4cb70fe2b723daa1529ce97ce3f0aa7e7b3 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/README.md @@ -0,0 +1,460 @@ + + +[![License](https://img.shields.io/badge/License-BSD3-lightgrey.svg)](https://opensource.org/licenses/BSD-3-Clause) + +# Triton Example Backends + +To learn how to create a Triton backend, and to see a best-practices +baseline onto which you can add your own backend log, follow the +[Tutorial](#tutorial). + +Triton also provides a couple of example backends that demonstrate +specific aspects of the backend API not covered by the +[Tutorial](#tutorial). + +* The +[*repeat*](https://github.com/triton-inference-server/repeat_backend) +backend shows a more advanced example of how a backend can produce +multiple responses per request. + +* The +[*stateful*](https://github.com/triton-inference-server/stateful_backend) +backend shows an example of how a backend can manage model state +tensors on the server-side for the [sequence +batcher](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#sequence-batcher) +to avoid transferring state tensors between client and server. Triton +also implements [Implicit State +Management](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#implicit-state-management) +which allows backends to behave in a stateless manner and leave the +state management to Triton. + +## Tutorial + +The [Triton Backend API](../README.md#triton-backend-api) exposes a +large number of features. The backend utilities and classes provide +many functions commonly used when creating a backend. But to create a +functional backend it is not necessary to use most of the backend API +or utilities. The tutorial starts with an implementation that shows a +*minimal* backend and then adds on recommended and optional +enhancements. The tutorial implementations follow best practices for +Triton backends and so can be used as templates for your own backend. + +### *Minimal* Triton Backend + +The source code for the *minimal* backend is contained in +[minimal.cc](backends/minimal/src/minimal.cc). The source code +contains extensive documentation describing the operation of the +backend and the use of the [Triton Backend +API](../README.md#triton-backend-api) and the backend +utilities. 
Before reading the source code, make sure you understand +the concepts associated with Triton backend abstractions +[TRITONBACKEND_Backend](../README.md#tritonbackend_backend), +[TRITONBACKEND_Model](../README.md#tritonbackend_model), and +[TRITONBACKEND_ModelInstance](../README.md#tritonbackend_modelinstance). + +The *minimal* backend does not do any interesting operation, it simply +copies a single input tensor to a single output tensor, but it does +demonstrate the basic organization required for a Triton backend. + +The *minimal* backend is complete but for clarity leaves out some +important aspects of writing a full-featured backend that are +described in [*Recommended* Triton +Backend](#recommended-triton-backend). When creating your own backend +use the [*Recommended* Triton Backend](#recommended-triton-backend) as +a starting point. + +#### Building the *Minimal* Backend + +[backends/minimal/CMakeLists.txt](backends/minimal/CMakeLists.txt) +shows the recommended build and install script for a Triton +backend. To build the *minimal* backend and install in a local directory +use the following commands. + +``` +$ cd backends/minimal +$ mkdir build +$ cd build +$ cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install .. +$ make install +``` + +The following required Triton repositories will be pulled and used in +the build. By default the "main" branch/tag will be used for each repo +but the listed CMake argument can be used to override. + +* triton-inference-server/backend: -DTRITON_BACKEND_REPO_TAG=[tag] +* triton-inference-server/core: -DTRITON_CORE_REPO_TAG=[tag] +* triton-inference-server/common: -DTRITON_COMMON_REPO_TAG=[tag] + +If you are building on a release branch (or on a development branch +that is based off of a release branch), then you must set these cmake +arguments to point to that release branch as well. For example, if you +are building the r21.10 identity_backend branch then you need to use +the following additional cmake flags: + +``` +-DTRITON_BACKEND_REPO_TAG=r21.10 +-DTRITON_CORE_REPO_TAG=r21.10 +-DTRITON_COMMON_REPO_TAG=r21.10 +``` + +After building the install directory will contain a backends/minimal +directory that contains the *minimal* backend. Instructions for adding +this backend to the Triton server are described in [Backend Shared +Library](../README.md#backend-shared-library). + +#### Running Triton with the *Minimal* Backend + +After adding the *minimal* backend to the Triton server as described +in [Backend Shared Library](../README.md#backend-shared-library), you +can run Triton and have it load the models in +[model_repos/minimal_models](model_repos/minimal_models). Assuming you +have created a *tritonserver* Docker image by adding the *minimal* +backend to Triton, the following command will run Triton: + +``` +$ docker run --rm -it --net=host -v/path/to/model_repos/minimal_models:/models tritonserver --model-repository=/models +``` + +The console output will show similar to the following indicating that +the *batching* and *nonbatching* models from the minimal_models +repository have loaded correctly. Note that the model repository has +two models that both use the *minimal* backend. A backend can support +any number of diffent models. 
+ +``` +I1215 23:46:00.250284 68 server.cc:589] ++-------------+---------+--------+ +| Model | Version | Status | ++-------------+---------+--------+ +| batching | 1 | READY | +| nonbatching | 1 | READY | ++-------------+---------+--------+ +``` + +The models are identical except that the *batching* model enabled the +[dynamic +batcher](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher) +and supports batch sizes up to 8. Note that the *batching* model sets +the [batch +delay](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching) +to 5 seconds so that the example client described below can +demonstrate how the *minimal* backend receives a batch of requests. + +#### Testing the *Minimal* Backend + +The [clients](clients) directory holds example clients. The +[minimal_client](clients/minimal_client) Python script demonstrates +sending a couple of inference requests to the *minimal* backend. With +Triton running as described in [Running Triton with the *Minimal* +Backend](#running-triton-with-the-minimal-backend), execute the +client: + +``` +$ clients/minimal_client +``` + +The minimal_client first sends a single request to nonbatching +model. From the output you can see that the input value is returned in +the output. + +``` +========= +Sending request to nonbatching model: IN0 = [1 2 3 4] +Response: {'model_name': 'nonbatching', 'model_version': '1', 'outputs': [{'name': 'OUT0', 'datatype': 'INT32', 'shape': [4], 'parameters': {'binary_data_size': 16}}]} +OUT0 = [1 2 3 4] +``` + +In the Triton console output you can see the log message printed by +the *minimal* backend that indicates that it received a batch +containing the single request. + +``` +I1221 18:14:12.964836 86 minimal.cc:348] model nonbatching: requests in batch 1 +I1221 18:14:12.964857 86 minimal.cc:356] batched IN0 value: [ 1, 2, 3, 4 ] +``` + +The minimal_client next sends 2 requests at the same time to the +batching model. Triton will dynamically batch those requests into a +single batch and send that single batch to the *minimal* backend. + +``` +========= +Sending request to batching model: IN0 = [[10 11 12 13]] +Sending request to batching model: IN0 = [[20 21 22 23]] +Response: {'model_name': 'batching', 'model_version': '1', 'outputs': [{'name': 'OUT0', 'datatype': 'INT32', 'shape': [1, 4], 'parameters': {'binary_data_size': 16}}]} +OUT0 = [[10 11 12 13]] +Response: {'model_name': 'batching', 'model_version': '1', 'outputs': [{'name': 'OUT0', 'datatype': 'INT32', 'shape': [1, 4], 'parameters': {'binary_data_size': 16}}]} +OUT0 = [[20 21 22 23]] +``` + +In the Triton console output you can see the log message indicating +that the *minimal* backend received a batch containing both requests. + +``` +I1221 18:14:17.965982 86 minimal.cc:348] model batching: requests in batch 2 +I1221 18:14:17.966035 86 minimal.cc:356] batched IN0 value: [ 10, 11, 12, 13, 20, 21, 22, 23 ] +``` + +### *Recommended* Triton Backend + +The source code for the *recommended* backend is contained in +[recommended.cc](backends/recommended/src/recommended.cc). The source +code contains extensive documentation describing the operation of the +backend and the use of the [Triton Backend +API](../README.md#triton-backend-api) and the backend +utilities. 
Before reading the source code, make sure you understand +the concepts associated with Triton backend abstractions +[TRITONBACKEND_Backend](../README.md#tritonbackend_backend), +[TRITONBACKEND_Model](../README.md#tritonbackend_model), and +[TRITONBACKEND_ModelInstance](../README.md#tritonbackend_modelinstance). + +The *recommended* backend improves the [*minimal* +backend](#minimal-triton-backend) to include the following features +which should be present in any robust backend implementation: + +* Enhances the backend to support models with input/output tensors + that have datatypes other than INT32. + +* Enhances the backend to support models with input/output tensors + that have any shape. + +* Uses the Triton backend metric APIs to record statistics about + requests executing in the backend. These metrics can then we queried + using the Triton + [metrics](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md) + and + [statistics](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_statistics.md) + APIs. + +* Additional error checking to ensure that the backend's version is + compatible with Triton and that each model's configuration is + compatible with the backend. + +As with the *minimal* backend, the *recommended* backend just returns +the input tensor value in the output tensor. Because of the additions +described above, the *recommended* backend can serve as a starting +point for your backend. + +#### Building the *Recommended* Backend + +[backends/recommended/CMakeLists.txt](backends/recommended/CMakeLists.txt) +shows the recommended build and install script for a Triton +backend. Building and installing is the same as decribed in [Building +the *Minimal* Backend](#building-the-minimal-backend). + +#### Running Triton with the *Recommended* Backend + +After adding the *recommended* backend to the Triton server as +described in [Backend Shared +Library](../README.md#backend-shared-library), you can run Triton and +have it load the models in +[model_repos/recommended_models](model_repos/recommended_models). Assuming +you have created a *tritonserver* Docker image by adding the +*recommended* backend to Triton, the following command will run +Triton: + +``` +$ docker run --rm -it --net=host -v/path/to/model_repos/recommended_models:/models tritonserver --model-repository=/models +``` + +The console output will show similar to the following indicating that +the *batching* model from the recommended_models repository have +loaded correctly. + +``` +I1215 23:46:00.250284 68 server.cc:589] ++-------------+---------+--------+ +| Model | Version | Status | ++-------------+---------+--------+ +| batching | 1 | READY | ++-------------+---------+--------+ +``` + +#### Testing the *Recommended* Backend + +The [clients](clients) directory holds example clients. The +[recommended_client](clients/recommended_client) Python script +demonstrates sending a couple of inference requests to the +*recommended* backend. With Triton running as described in [Running +Triton with the *Recommended* +Backend](#running-triton-with-the-recommended-backend), execute the +client: + +``` +$ clients/recommended_client +``` + +The recommended_client next sends 2 requests at the same time to the +batching model, similar to what was done above with the *minimal* +backend. Triton will dynamically batch those requests into a single +batch and send that single batch to the *recommended* backend. 
In this +model, batching is supported, the datatype is FP32 and the tensor +shape is [ -1, 4, 4 ]. + +``` +========= +Sending request to batching model: input = [[[1. 1.1 1.2 1.3] + [2. 2.1 2.2 2.3] + [3. 3.1 3.2 3.3] + [4. 4.1 4.2 4.3]]] +Sending request to batching model: input = [[[10. 10.1 10.2 10.3] + [20. 20.1 20.2 20.3] + [30. 30.1 30.2 30.3] + [40. 40.1 40.2 40.3]]] +Response: {'model_name': 'batching', 'model_version': '1', 'outputs': [{'name': 'OUTPUT', 'datatype': 'FP32', 'shape': [1, 4, 4], 'parameters': {'binary_data_size': 64}}]} +OUTPUT = [[[1. 1.1 1.2 1.3] + [2. 2.1 2.2 2.3] + [3. 3.1 3.2 3.3] + [4. 4.1 4.2 4.3]]] +Response: {'model_name': 'batching', 'model_version': '1', 'outputs': [{'name': 'OUTPUT', 'datatype': 'FP32', 'shape': [1, 4, 4], 'parameters': {'binary_data_size': 64}}]} +OUTPUT = [[[10. 10.1 10.2 10.3] + [20. 20.1 20.2 20.3] + [30. 30.1 30.2 30.3] + [40. 40.1 40.2 40.3]]] +``` + +In the Triton console output you can see the log message indicating +that the *recommended* backend received a batch containing both +requests. + +``` +I1221 18:30:52.223226 127 recommended.cc:604] model batching: requests in batch 2 +I1221 18:30:52.223313 127 recommended.cc:613] batched INPUT value: [ 1.000000, 1.100000, 1.200000, 1.300000, 2.000000, 2.100000, 2.200000, 2.300000, 3.000000, 3.100000, 3.200000, 3.300000, 4.000000, 4.100000, 4.200000, 4.300000, 10.000000, 10.100000, 10.200000, 10.300000, 20.000000, 20.100000, 20.200001, 20.299999, 30.000000, 30.100000, 30.200001, 30.299999, 40.000000, 40.099998, 40.200001, 40.299999 ] +``` + +Because the *recommended* backend can support models that have +input/output tensors with any datatype and shape, you can edit the +model configuration and the client to experiment with these options. + +To see the metrics collected for these two inference requests, use the following command to access Triton's metrics endpoint. + +``` +$ curl localhost:8002/metrics +``` + +The output will be metric values in Prometheus data format. The +[metrics +documentation](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md) +gives a description of these metric values. + +``` +# HELP nv_inference_request_success Number of successful inference requests, all batch sizes +# TYPE nv_inference_request_success counter +nv_inference_request_success{model="batching",version="1"} 2.000000 +# HELP nv_inference_request_failure Number of failed inference requests, all batch sizes +# TYPE nv_inference_request_failure counter +nv_inference_request_failure{model="batching",version="1"} 0.000000 +# HELP nv_inference_count Number of inferences performed +# TYPE nv_inference_count counter +nv_inference_count{model="batching",version="1"} 2.000000 +# HELP nv_inference_exec_count Number of model executions performed +# TYPE nv_inference_exec_count counter +nv_inference_exec_count{model="batching",version="1"} 1.000000 +... +``` + +You can also see the collected statistics using the [statistics +endpoint](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_statistics.md). 
+ +``` +$ curl localhost:8000/v2/models/batching/stats +{"model_stats":[{"name":"batching","version":"1","last_inference":1640111452223,"inference_count":2,"execution_count":1,"inference_stats":{"success":{"count":2,"ns":9997025869},"fail":{"count":0,"ns":0},"queue":{"count":2,"ns":9996491319},"compute_input":{"count":2,"ns":95288},"compute_infer":{"count":2,"ns":232202},"compute_output":{"count":2,"ns":195850}},"batch_stats":[{"batch_size":2,"compute_input":{"count":1,"ns":47644},"compute_infer":{"count":1,"ns":116101},"compute_output":{"count":1,"ns":97925}}]}]} +``` + +### *BLS* Triton Backend + +Please see the [doucumentation](backends/bls/README.md) of *BLS* Backend. + +### Enhancements + +This section describes several optional features that you can add to +enhance the capabilities of your backend. + +#### Automatically Model Configuration Generation + +[Automatic model configuration +generation](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#auto-generated-model-configuration) +is enabled by the backend implementing the appropriate logic (for +example, in a function called AutoCompleteConfig) during +TRITONBACKEND_ModelInitialize. For the *recommended* backend you would +add a call to AutoCompleteConfig in the ModelState constructor just +before the call to ValidateModelConfig. The AutoCompleteConfig +function can update the model configuration with input tensor, output +tensor, and max-batch-size configuration; and then update the +configuration using TRITONBACKEND_ModelSetConfig. Examples can be +found in [ONNXRuntime +backend](https://github.com/triton-inference-server/onnxruntime_backend), +[TensorFlow +backend](https://github.com/triton-inference-server/tensorflow_backend) +and other backends. + +#### Add Key-Value Parameters to a Response + +A backend can add a key-value pair to a response any time after the +response is created and before it is sent. The parameter key must be a +string and the parameter value can be a string, integer or +boolean. The following example shows the TRITONBACKEND API used to set +response parameters. Error checking code is not shown to improve +clarity. + +``` +TRITONBACKEND_ResponseSetStringParameter(response, "param0", "an example string parameter"); +TRITONBACKEND_ResponseSetIntParameter(responses[r], "param1", 42); +TRITONBACKEND_ResponseSetBoolParameter(responses[r], "param2", false); +``` + +#### Access Model Artifacts in the Model Repository + +A backend can access any of the files in a model's area of the model +registry. These files are typically needed during +TRITONBACKEND_ModelInitialize but can be accessed at other times as +well. The TRITONBACKEND_ModelRepository API gives the location of the +model's repository. For example, the following code can be run during +TRITONBACKEND_ModelInitialize to write the location to the log. + +``` +// Can get location of the model artifacts. Normally we would need +// to check the artifact type to make sure it was something we can +// handle... but we are just going to log the location so we don't +// need the check. We would use the location if we wanted to load +// something from the model's repo. 
+TRITONBACKEND_ArtifactType artifact_type; +const char* clocation; +RETURN_IF_ERROR( + TRITONBACKEND_ModelRepository(model, &artifact_type, &clocation)); +LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Repository location: ") + clocation).c_str()); +``` + +The framework backends (for example, TensorRT, ONNXRuntime, +TensorFlow, PyTorch) read the actual model file from the model +repository using this API. See those backends for examples of how it +can be used. diff --git a/3rdparty/backend-r22.12/examples/backends/bls/README.md b/3rdparty/backend-r22.12/examples/backends/bls/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eae8390cad8bfdb0e4876ae460b30bedc5c0cfd1 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/bls/README.md @@ -0,0 +1,135 @@ + + +# *BLS* Triton Backend + +The [*BLS*](../bls) backend demonstrates using in-process C-API to +execute inferences within the backend. This backend serves as an example to +backend developers for implementing their own custom pipeline in C++. +For Python use cases, please refer to +[Business Logic Scripting](https://github.com/triton-inference-server/python_backend/blob/main/README.md#business-logic-scripting) +section in Python backend. + +The source code for the *bls* backend is contained in +[src](./src). + +* [backend.cc](./src/backend.cc) contains the main backend +implementation. The content of this file is not BLS specific. It only includes +the required Triton backend functions that is standard for any backend +implementation. The BLS logic is set off in the +`TRITONBACKEND_ModelInstanceExecute` with lines `bls_executor.Execute(requests[r], &responses[r]);`. + +* [bls.h](./src/bls.h) is where the BLS (class `BLSExecutor`) of +this example is located. You can refer to this file to see how to interact with +Triton in-process C-API to build the custom execution pipeline. + +* [bls_utils.h](./src/bls_utils.h) is where all the utilities that +are not BLS dependent are located. + +The source code contains extensive documentation describing the operation of +the backend and the use of the +[Triton Backend API](../../../README.md#triton-backend-api) and the +[Triton Server API](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/inference_protocols.md#in-process-triton-server-api). +Before reading the source code, make sure you understand +the concepts associated with Triton backend abstractions +[TRITONBACKEND_Backend](../../../README.md#tritonbackend_backend), +[TRITONBACKEND_Model](../../../README.md#tritonbackend_model), and +[TRITONBACKEND_ModelInstance](../../../README.md#tritonbackend_modelinstance). + +The *bls* backend will send two requests on the 'addsub_python' and 'addsub_tf' +models. After the inference requests are completed, this backend will extract +OUTPUT0 from the 'addsub_python' and OUTPUT1 from the 'addsub_tf' model to +construct the final inference response object using these tensors. + +There are some self-imposed limitations that were made for the simplicity of +this example: +1. This backend does not support batching. +2. This backend does not support decoupled models. +3. This backend does not support GPU tensors. +4. The model configuraion should be strictly set as the comments described in +[backend.cc](./src/backend.cc). + +You can implement your custom backend that is not limited to the limitations +mentioned above. 
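+
+To make the flow concrete, the sketch below condenses how the per-request
+loop in `TRITONBACKEND_ModelInstanceExecute` can drive the executor. It is
+an illustration, not a copy of [backend.cc](./src/backend.cc): the
+`BLSExecutor` constructor shown, its namespace, and the choice to send and
+release everything from the caller are assumptions made for brevity, and
+error handling is reduced to the utility macros.
+
+```
+#include <vector>
+
+#include "bls.h"  // BLSExecutor, from this example's src/ directory
+#include "triton/backend/backend_common.h"  // LOG_IF_ERROR, RETURN_IF_ERROR
+#include "triton/core/tritonbackend.h"
+
+extern "C" TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceExecute(
+    TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests,
+    const uint32_t request_count)
+{
+  // The in-process server handle needed for BLS calls is reachable from
+  // the instance through its model.
+  TRITONBACKEND_Model* model;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model));
+  TRITONSERVER_Server* server;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelServer(model, &server));
+
+  // One response per request; this backend always produces exactly one.
+  std::vector<TRITONBACKEND_Response*> responses;
+  responses.reserve(request_count);
+  for (uint32_t r = 0; r < request_count; r++) {
+    TRITONBACKEND_Response* response;
+    RETURN_IF_ERROR(TRITONBACKEND_ResponseNew(&response, requests[r]));
+    responses.push_back(response);
+  }
+
+  // Assumed constructor: the executor only needs the server handle.
+  triton::backend::bls::BLSExecutor bls_executor(server);
+  for (uint32_t r = 0; r < request_count; r++) {
+    // Runs 'addsub_python' and 'addsub_tf' through the in-process C-API
+    // and assembles OUTPUT0/OUTPUT1 into responses[r].
+    bls_executor.Execute(requests[r], &responses[r]);
+  }
+
+  for (uint32_t r = 0; r < request_count; r++) {
+    LOG_IF_ERROR(
+        TRITONBACKEND_ResponseSend(
+            responses[r], TRITONSERVER_RESPONSE_COMPLETE_FINAL,
+            nullptr /* success */),
+        "failed to send BLS response");
+    LOG_IF_ERROR(
+        TRITONBACKEND_RequestRelease(
+            requests[r], TRITONSERVER_REQUEST_RELEASE_ALL),
+        "failed to release BLS request");
+  }
+
+  return nullptr;  // success
+}
+```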
+ +## Building the *BLS* Backend + +[backends/bls/CMakeLists.txt](CMakeLists.txt) +shows the recommended build and install script for a Triton +backend. Building and installing is the same as decribed in [Building +the *Minimal* Backend](../../README.md#building-the-minimal-backend). + +## Running Triton with the *BLS* Backend + +After adding the *bls* backend to the Triton server as +described in [Backend Shared +Library](../../../README.md#backend-shared-library), you can run Triton and +have it load the models in +[model_repos/bls_models](../../model_repos/bls_models). Assuming you have created a +*tritonserver* Docker image by adding the *bls* backend to Triton, the +following command will run Triton: + +``` +$ docker run --rm -it --net=host -v/path/to/model_repos/bls_models:/models tritonserver --model-repository=/models +``` + +The console output will show similar to the following indicating that +the *bls_fp32*, *addsub_python* and *addsub_tf* models from the bls_models repository have +loaded correctly. + +``` +I0616 09:34:47.767433 19214 server.cc:629] ++---------------+---------+--------+ +| Model | Version | Status | ++---------------+---------+--------+ +| addsub_python | 1 | READY | +| addsub_tf | 1 | READY | +| bls_fp32 | 1 | READY | ++---------------+---------+--------+ +``` + +## Testing the *BLS* Backend + +The [clients](../../clients) directory holds example clients. The +[bls_client](../../clients/bls_client) Python script demonstrates sending an +inference requests to the *bls* backend. With Triton running as +described in [Running Triton with the *BLS* Backend](#running-triton-with-the-bls-backend), +execute the client: + +``` +$ clients/bls_client +``` + +You should see an output similar to the output below: + +``` +INPUT0 ([0.42935285 0.51512766 0.43625894 ... 0.6670954 0.17747518 0.7976901 ]) + INPUT1 ([6.7752063e-01 2.4223252e-01 6.7743927e-01 ... 4.1531715e-01 2.5451833e-01 7.9097062e-01]) = OUTPUT0 ([1.1068735 0.75736016 1.1136982 ... 1.0824126 0.4319935 1.5886607 ]) +INPUT0 ([0.42935285 0.51512766 0.43625894 ... 0.6670954 0.17747518 0.7976901 ]) - INPUT1 ([6.7752063e-01 2.4223252e-01 6.7743927e-01 ... 4.1531715e-01 2.5451833e-01 7.9097062e-01]) = OUTPUT1 ([-0.24816778 0.27289516 -0.24118033 ... 0.25177827 -0.07704315 0.00671947]) + +PASS +``` diff --git a/3rdparty/backend-r22.12/examples/backends/bls/cmake/TritonBLSBackendConfig.cmake.in b/3rdparty/backend-r22.12/examples/backends/bls/cmake/TritonBLSBackendConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..dd41ae7aeb9cb34277f7b49bf15e00d5fd1fc007 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/bls/cmake/TritonBLSBackendConfig.cmake.in @@ -0,0 +1,39 @@ +# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include(CMakeFindDependencyMacro) + +get_filename_component( + TRITONBLSBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH +) + +list(APPEND CMAKE_MODULE_PATH ${TRITONBLSBACKEND_CMAKE_DIR}) + +if(NOT TARGET TritonBLSBackend::triton-bls-backend) + include("${TRITONBLSBACKEND_CMAKE_DIR}/TritonBLSBackendTargets.cmake") +endif() + +set(TRITONBLSBACKEND_LIBRARIES TritonBLSBackend::triton-bls-backend) diff --git a/3rdparty/backend-r22.12/examples/backends/bls/src/backend.cc b/3rdparty/backend-r22.12/examples/backends/bls/src/backend.cc new file mode 100644 index 0000000000000000000000000000000000000000..66f1c17a13506f53a96294a0ff0f11aead5d7d95 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/bls/src/backend.cc @@ -0,0 +1,526 @@ +// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "bls.h" +#include "triton/backend/backend_model.h" +#include "triton/backend/backend_model_instance.h" + +// +// Backend that demonstrates using in-process C-API to execute inferences +// within the backend. +// +// Two particular models, 'addsub_python' and 'addsub_tf', must be loaded on +// the server for a successful inference execution on this backend. 
+// +// The model configuration should be set as follows in order to be in line with +// the 'addsub_python' and 'addsub_tf' models. This backend does not support +// batching. These limitations are only for this specific backend. You can +// implement your custom BLS backend with less limitations. +// +// Model Configuration: +// - Input 'INPUT0' must have shape [16] and datatype must be TYPE_FP32. +// +// - Input 'INPUT1' must have shape [16] and datatype must be TYPE_FP32. +// +// - For each response, output 'OUTPUT0' must have shape [16] and +// datatype TYPE_FP32. +// +// - For each response, output 'OUTPUT1' must have shape [16] and +// datatype TYPE_FP32. +// +// This backend will send two requests on the 'addsub_python' and 'addsub_tf' +// models. After the inference requests are completed, this backend +// will extract OUTPUT0 from the 'addsub_python' and OUTPUT1 from the +// 'addsub_tf' model to construct the final inference response object using +// these tensors. + +namespace triton { namespace backend { namespace bls { + +// +// ModelState +// +// State associated with a model that is using this backend. An object +// of this class is created and associated with each +// TRITONBACKEND_Model. +// +class ModelState : public BackendModel { + public: + static TRITONSERVER_Error* Create( + TRITONBACKEND_Model* triton_model, ModelState** state); + virtual ~ModelState() = default; + + // Validate that model configuration is supported by this backend. + TRITONSERVER_Error* ValidateModelConfig(); + + private: + ModelState(TRITONBACKEND_Model* triton_model) : BackendModel(triton_model) {} +}; + +TRITONSERVER_Error* +ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) +{ + try { + *state = new ModelState(triton_model); + } + catch (const BackendModelException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +ModelState::ValidateModelConfig() +{ + // We have the json DOM for the model configuration... + common::TritonJson::WriteBuffer buffer; + RETURN_IF_ERROR(model_config_.PrettyWrite(&buffer)); + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("model configuration:\n") + buffer.Contents()).c_str()); + + // max_batch_size must be 0 because this backend does not support + // batching + int64_t max_batch_size; + RETURN_IF_ERROR(model_config_.MemberAsInt("max_batch_size", &max_batch_size)); + RETURN_ERROR_IF_FALSE( + max_batch_size == 0, TRITONSERVER_ERROR_INVALID_ARG, + std::string("bls backend only supports models with max_batch_size == 0")); + + common::TritonJson::Value inputs, outputs; + RETURN_IF_ERROR(model_config_.MemberAsArray("input", &inputs)); + RETURN_IF_ERROR(model_config_.MemberAsArray("output", &outputs)); + + // There must be 2 inputs and 2 outputs. + RETURN_ERROR_IF_FALSE( + inputs.ArraySize() == 2, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected 2 inputs, got ") + + std::to_string(inputs.ArraySize())); + RETURN_ERROR_IF_FALSE( + outputs.ArraySize() == 2, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected 2 outputs, got ") + + std::to_string(outputs.ArraySize())); + + // Here we rely on the model configuation listing the inputs and + // outputs in a specific order, which we shouldn't really require... 
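+  // For reference, a model configuration that passes the checks below looks
+  // roughly like the following (illustrative sketch only; see the models in
+  // model_repos/bls_models for the actual configuration files):
+  //
+  //   max_batch_size: 0
+  //   input [
+  //     { name: "INPUT0", data_type: TYPE_FP32, dims: [ 16 ] },
+  //     { name: "INPUT1", data_type: TYPE_FP32, dims: [ 16 ] }
+  //   ]
+  //   output [
+  //     { name: "OUTPUT0", data_type: TYPE_FP32, dims: [ 16 ] },
+  //     { name: "OUTPUT1", data_type: TYPE_FP32, dims: [ 16 ] }
+  //   ]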
+ common::TritonJson::Value input0, input1, output0, output1; + RETURN_IF_ERROR(inputs.IndexAsObject(0, &input0)); + RETURN_IF_ERROR(inputs.IndexAsObject(1, &input1)); + RETURN_IF_ERROR(outputs.IndexAsObject(0, &output0)); + RETURN_IF_ERROR(outputs.IndexAsObject(1, &output1)); + + // Check tensor names + std::string in0_name, in1_name, out0_name, out1_name; + RETURN_IF_ERROR(input0.MemberAsString("name", &in0_name)); + RETURN_IF_ERROR(input1.MemberAsString("name", &in1_name)); + RETURN_IF_ERROR(output0.MemberAsString("name", &out0_name)); + RETURN_IF_ERROR(output1.MemberAsString("name", &out1_name)); + + RETURN_ERROR_IF_FALSE( + in0_name == "INPUT0", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected first input tensor name to be INPUT0, got ") + + in0_name); + RETURN_ERROR_IF_FALSE( + in1_name == "INPUT1", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected second input tensor name to be INPUT1, got ") + + in1_name); + RETURN_ERROR_IF_FALSE( + out0_name == "OUTPUT0", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected first output tensor name to be OUTPUT0, got ") + + out0_name); + RETURN_ERROR_IF_FALSE( + out1_name == "OUTPUT1", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected second output tensor name to be OUTPUT1, got ") + + out1_name); + + // Check shapes + std::vector in0_shape, in1_shape, out0_shape, out1_shape; + RETURN_IF_ERROR(backend::ParseShape(input0, "dims", &in0_shape)); + RETURN_IF_ERROR(backend::ParseShape(input1, "dims", &in1_shape)); + RETURN_IF_ERROR(backend::ParseShape(output0, "dims", &out0_shape)); + RETURN_IF_ERROR(backend::ParseShape(output1, "dims", &out1_shape)); + + RETURN_ERROR_IF_FALSE( + in0_shape.size() == 1, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected INPUT0 shape to have one dimension, got ") + + backend::ShapeToString(in0_shape)); + RETURN_ERROR_IF_FALSE( + in1_shape.size() == 1, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected INPUT1 shape to have one dimension, got ") + + backend::ShapeToString(in1_shape)); + RETURN_ERROR_IF_FALSE( + out0_shape.size() == 1, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected OUTPUT0 shape to have one dimension, got ") + + backend::ShapeToString(out0_shape)); + RETURN_ERROR_IF_FALSE( + out1_shape.size() == 1, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected OUTPUT1 shape to have one dimension, got ") + + backend::ShapeToString(out1_shape)); + + // Check datatypes + std::string in0_dtype, in1_dtype, out0_dtype, out1_dtype; + RETURN_IF_ERROR(input0.MemberAsString("data_type", &in0_dtype)); + RETURN_IF_ERROR(input1.MemberAsString("data_type", &in1_dtype)); + RETURN_IF_ERROR(output0.MemberAsString("data_type", &out0_dtype)); + RETURN_IF_ERROR(output1.MemberAsString("data_type", &out1_dtype)); + + RETURN_ERROR_IF_FALSE( + in0_dtype == "TYPE_FP32", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected INPUT0 datatype to be TYPE_FP32, got ") + + in0_dtype); + RETURN_ERROR_IF_FALSE( + in1_dtype == "TYPE_FP32", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected INPUT1 datatype to be TYPE_FP32, got ") + + in1_dtype); + RETURN_ERROR_IF_FALSE( + out0_dtype == "TYPE_FP32", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected OUTPUT0 datatype to be TYPE_FP32, got ") + + out0_dtype); + RETURN_ERROR_IF_FALSE( + out1_dtype == "TYPE_FP32", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected OUTPUT1 datatype to be TYPE_FP32, got ") + + out1_dtype); + + return nullptr; // success +} + +// +// ModelInstanceState +// +// State associated with a model instance. 
An object of this class is +// created and associated with each TRITONBACKEND_ModelInstance. +// +class ModelInstanceState : public BackendModelInstance { + public: + static TRITONSERVER_Error* Create( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state); + virtual ~ModelInstanceState() = default; + + void ProcessRequests( + TRITONBACKEND_Request** requests, const uint32_t request_count); + + private: + ModelInstanceState( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance) + : BackendModelInstance(model_state, triton_model_instance) + { + } +}; + +TRITONSERVER_Error* +ModelInstanceState::Create( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state) +{ + try { + *state = new ModelInstanceState(model_state, triton_model_instance); + } + catch (const BackendModelInstanceException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelInstanceException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} + +void +ModelInstanceState::ProcessRequests( + TRITONBACKEND_Request** requests, const uint32_t request_count) +{ + uint64_t exec_start_ns = 0; + SET_TIMESTAMP(exec_start_ns); + + for (size_t i = 0; i < request_count; i++) { + // If we get a nullptr request then something is badly wrong. Fail + // and release all requests. + if (requests[i] == nullptr) { + RequestsRespondWithError( + requests, request_count, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string( + "null request given to BLS backend for '" + Name() + "'") + .c_str())); + return; + } + } + + // At this point we accept ownership of 'requests', which means that + // even if something goes wrong we must still return success from + // this function. If something does go wrong in processing a + // particular request then we send an error response just for the + // specific request. + std::vector responses; + responses.reserve(request_count); + + for (size_t i = 0; i < request_count; i++) { + TRITONBACKEND_Response* response; + auto err = TRITONBACKEND_ResponseNew(&response, requests[i]); + if (err == nullptr) { + responses.emplace_back(response); + } else { + responses.emplace_back(nullptr); + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response"); + TRITONSERVER_ErrorDelete(err); + } + } + + ModelState* model_state = reinterpret_cast(Model()); + + // The way we collect these batch timestamps is not entirely + // accurate. Normally, in a performant backend you would execute all + // the requests at the same time, and so there would be a single + // compute-start / compute-end time-range. But here we execute each + // request separately so there is no single range. As a result we + // just show the entire execute time as being the compute time as + // well. + uint64_t compute_start_ns = 0; + SET_TIMESTAMP(compute_start_ns); + + // Create a BLSExecutor object. To separate from standard backend + // implementation, the BLS logic is placed inside class BLSExecutor. + BLSExecutor bls_executor(model_state->TritonServer()); + + for (size_t r = 0; r < request_count; r++) { + bls_executor.Execute(requests[r], &responses[r]); + } + + uint64_t compute_end_ns = 0; + SET_TIMESTAMP(compute_end_ns); + + uint64_t exec_end_ns = 0; + SET_TIMESTAMP(exec_end_ns); + + // Send all the responses that haven't already been sent because of + // an earlier error. 
Note that the responses are not set to nullptr + // here as we need that indication below to determine if the request + // we successful or not. + for (auto& response : responses) { + if (response != nullptr) { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr), + "failed to send BLS backend response"); + } + } + + // Report statistics for each request. + for (uint32_t r = 0; r < request_count; ++r) { + auto& request = requests[r]; + LOG_IF_ERROR( + TRITONBACKEND_ModelInstanceReportStatistics( + TritonModelInstance(), request, + (responses[r] != nullptr) /* success */, exec_start_ns, + compute_start_ns, compute_end_ns, exec_end_ns), + "failed reporting request statistics"); + + LOG_IF_ERROR( + TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), + "failed releasing request"); + } + + // Report the entire batch statistics. + LOG_IF_ERROR( + TRITONBACKEND_ModelInstanceReportBatchStatistics( + TritonModelInstance(), 1 /*total_batch_size*/, exec_start_ns, + compute_start_ns, compute_end_ns, exec_end_ns), + "failed reporting batch request statistics"); + + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("TRITONBACKEND_ModelExecute: model ") + Name() + + " released " + std::to_string(request_count) + " requests") + .c_str()); +} + +///////////// + +extern "C" { + +// Implementing TRITONBACKEND_ModelInitialize is optional. The backend +// should initialize any state that is intended to be shared across +// all instances of the model. +TRITONSERVER_Error* +TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) +{ + const char* cname; + RETURN_IF_ERROR(TRITONBACKEND_ModelName(model, &cname)); + std::string name(cname); + + uint64_t version; + RETURN_IF_ERROR(TRITONBACKEND_ModelVersion(model, &version)); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("TRITONBACKEND_ModelInitialize: ") + name + " (version " + + std::to_string(version) + ")") + .c_str()); + + // With each model we create a ModelState object and associate it + // with the TRITONBACKEND_Model. + ModelState* model_state; + RETURN_IF_ERROR(ModelState::Create(model, &model_state)); + RETURN_IF_ERROR( + TRITONBACKEND_ModelSetState(model, reinterpret_cast(model_state))); + + // One of the primary things to do in ModelInitialize is to examine + // the model configuration to ensure that it is something that this + // backend can support. If not, returning an error from this + // function will prevent the model from loading. + RETURN_IF_ERROR(model_state->ValidateModelConfig()); + + return nullptr; // success +} + +// Implementing TRITONBACKEND_ModelFinalize is optional unless state +// is set using TRITONBACKEND_ModelSetState. The backend must free +// this state and perform any other cleanup. +TRITONSERVER_Error* +TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) +{ + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate)); + ModelState* model_state = reinterpret_cast(vstate); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, "TRITONBACKEND_ModelFinalize: delete model state"); + + delete model_state; + + return nullptr; // success +} + +// Implementing TRITONBACKEND_ModelInstanceInitialize is optional. The +// backend should initialize any state that is required for a model +// instance. 
+TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance) +{ + const char* cname; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceName(instance, &cname)); + std::string name(cname); + + int32_t device_id; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceDeviceId(instance, &device_id)); + TRITONSERVER_InstanceGroupKind kind; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceKind(instance, &kind)); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("TRITONBACKEND_ModelInstanceInitialize: ") + name + " (" + + TRITONSERVER_InstanceGroupKindString(kind) + " device " + + std::to_string(device_id) + ")") + .c_str()); + + // The instance can access the corresponding model as well... here + // we get the model and from that get the model's state. + TRITONBACKEND_Model* model; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); + + void* vmodelstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate)); + ModelState* model_state = reinterpret_cast(vmodelstate); + + // With each instance we create a ModelInstanceState object and + // associate it with the TRITONBACKEND_ModelInstance. + ModelInstanceState* instance_state; + RETURN_IF_ERROR( + ModelInstanceState::Create(model_state, instance, &instance_state)); + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState( + instance, reinterpret_cast(instance_state))); + + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("TRITONBACKEND_ModelInstanceInitialize: instance " + "initialization successful ") + + name + " (device " + std::to_string(device_id) + ")") + .c_str()); + + return nullptr; // success +} + +// Implementing TRITONBACKEND_ModelInstanceFinalize is optional unless +// state is set using TRITONBACKEND_ModelInstanceSetState. The backend +// must free this state and perform any other cleanup. +TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance) +{ + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate)); + ModelInstanceState* instance_state = + reinterpret_cast(vstate); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + "TRITONBACKEND_ModelInstanceFinalize: delete instance state"); + + delete instance_state; + + return nullptr; // success +} + +// Implementing TRITONBACKEND_ModelInstanceExecute is required. +TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceExecute( + TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, + const uint32_t request_count) +{ + // Triton will not call this function simultaneously for the same + // 'instance'. But since this backend could be used by multiple + // instances from multiple models the implementation needs to handle + // multiple calls to this function at the same time (with different + // 'instance' objects). Suggested practice for this is to use only + // function-local and model-instance-specific state (obtained from + // 'instance'), which is what we do here. 
+ ModelInstanceState* instance_state; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState( + instance, reinterpret_cast(&instance_state))); + ModelState* model_state = + reinterpret_cast(instance_state->Model()); + + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("model ") + model_state->Name() + ", instance " + + instance_state->Name() + ", executing " + std::to_string(request_count) + + " requests") + .c_str()); + + instance_state->ProcessRequests(requests, request_count); + + return nullptr; // success +} + +} // extern "C" + +}}} // namespace triton::backend::bls diff --git a/3rdparty/backend-r22.12/examples/backends/bls/src/bls.cc b/3rdparty/backend-r22.12/examples/backends/bls/src/bls.cc new file mode 100644 index 0000000000000000000000000000000000000000..a892a41498487c771b0e8c58254c6eaed88caf8a --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/bls/src/bls.cc @@ -0,0 +1,291 @@ +// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "bls.h" + +namespace triton { namespace backend { namespace bls { + +BLSExecutor::BLSExecutor(TRITONSERVER_Server* server) + : server_(server), model_executor_(server) +{ +} + +TRITONSERVER_Error* +BLSExecutor::PrepareInferenceRequest( + TRITONBACKEND_Request* bls_request, + TRITONSERVER_InferenceRequest** irequest, const std::string model_name) +{ + // Get request_id, correlation_id, and flags from the current request + // for preparing a new inference request that we will send to 'addsub_python' + // or 'addsub_tf' model later. + const char* request_id; + uint64_t correlation_id; + uint32_t flags; + RETURN_IF_ERROR(TRITONBACKEND_RequestId(bls_request, &request_id)); + RETURN_IF_ERROR( + TRITONBACKEND_RequestCorrelationId(bls_request, &correlation_id)); + RETURN_IF_ERROR(TRITONBACKEND_RequestFlags(bls_request, &flags)); + + // Create an inference request object. The inference request object + // is where we set the name of the model we want to use for + // inference and the input tensors. 
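+  // A model_version of -1 asks Triton to select the latest available
+  // version of the named model.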
+ RETURN_IF_ERROR(TRITONSERVER_InferenceRequestNew( + irequest, server_, model_name.c_str(), -1 /* model_version */)); + // Set request_id, correlation_id, and flags for the new request. + RETURN_IF_ERROR(TRITONSERVER_InferenceRequestSetId(*irequest, request_id)); + RETURN_IF_ERROR( + TRITONSERVER_InferenceRequestSetCorrelationId(*irequest, correlation_id)); + RETURN_IF_ERROR(TRITONSERVER_InferenceRequestSetFlags(*irequest, flags)); + RETURN_IF_ERROR(TRITONSERVER_InferenceRequestSetReleaseCallback( + *irequest, InferRequestComplete, nullptr /* request_release_userp */)); + + return nullptr; // success +} + +TRITONSERVER_Error* +BLSExecutor::PrepareInferenceInput( + TRITONBACKEND_Request* bls_request, TRITONSERVER_InferenceRequest* irequest) +{ + // Get the properties of the two inputs from the current request. + // Then, add the two input tensors and append the input data to the new + // request. + uint32_t input_count; + RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(bls_request, &input_count)); + + TRITONBACKEND_Input* input; + const char* name; + TRITONSERVER_DataType datatype; + const int64_t* shape; + uint32_t dims_count; + size_t data_byte_size; + TRITONSERVER_MemoryType data_memory_type; + int64_t data_memory_id; + const char* data_buffer; + + for (size_t count = 0; count < input_count; count++) { + RETURN_IF_ERROR(TRITONBACKEND_RequestInputByIndex( + bls_request, count /* index */, &input)); + RETURN_IF_ERROR(TRITONBACKEND_InputProperties( + input, &name, &datatype, &shape, &dims_count, nullptr, nullptr)); + RETURN_IF_ERROR(TRITONBACKEND_InputBuffer( + input, 0 /* idx */, reinterpret_cast(&data_buffer), + &data_byte_size, &data_memory_type, &data_memory_id)); + RETURN_IF_ERROR(TRITONSERVER_InferenceRequestAddInput( + irequest, name, datatype, shape, dims_count)); + RETURN_IF_ERROR(TRITONSERVER_InferenceRequestAppendInputData( + irequest, name, &data_buffer[0], data_byte_size, data_memory_type, + data_memory_id)); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +BLSExecutor::PrepareInferenceOutput( + TRITONBACKEND_Request* bls_request, TRITONSERVER_InferenceRequest* irequest) +{ + // Indicate the output tensors to be calculated and returned + // for the inference request. + uint32_t output_count; + RETURN_IF_ERROR(TRITONBACKEND_RequestOutputCount(bls_request, &output_count)); + const char* output_name; + for (size_t count = 0; count < output_count; count++) { + RETURN_IF_ERROR(TRITONBACKEND_RequestOutputName( + bls_request, count /* index */, &output_name)); + RETURN_IF_ERROR( + TRITONSERVER_InferenceRequestAddRequestedOutput(irequest, output_name)); + } + + return nullptr; // success +} + +void +BLSExecutor::Execute( + TRITONBACKEND_Request* bls_request, TRITONBACKEND_Response** response) +{ + // The names of the models that we will send internal requests on. + std::vector model_names = {"addsub_python", "addsub_tf"}; + + // Check if both models are valid before executing request. + try { + for (size_t i = 0; i < 2; i++) { + // Check if the model is ready. + bool is_ready = false; + THROW_IF_TRITON_ERROR(TRITONSERVER_ServerModelIsReady( + server_, model_names[i].c_str(), -1 /* model_version */, &is_ready)); + if (!is_ready) { + throw BLSBackendException( + (std::string("Failed to execute the inference request. Model '") + + model_names[i].c_str() + "' is not ready.") + .c_str()); + } + // For simplicity, decoupled API is not supported in this BLS backend. You + // can implement your own backend that supports decoupled models. 
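+      // The transaction-property query below is how this backend detects
+      // decoupled models: if TRITONSERVER_TXN_DECOUPLED is set in 'txn_flags'
+      // the request is rejected with a BLSBackendException.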
+ uint32_t txn_flags; + THROW_IF_TRITON_ERROR(TRITONSERVER_ServerModelTransactionProperties( + server_, model_names[i].c_str(), -1 /* model_version */, &txn_flags, + nullptr /* voidp */)); + if ((txn_flags & TRITONSERVER_TXN_DECOUPLED) != 0) { + throw BLSBackendException( + std::string("Model '") + model_names[i].c_str() + + "' is using the decoupled. This BLS Backend doesn't support models " + "using the decoupled transaction policy."); + } + } + } + catch (const BLSBackendException& bls_exception) { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, bls_exception.what()); + RESPOND_AND_SET_NULL_IF_ERROR( + response, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "Failed to send inference requests")); + return; + } + + // Prepare std::future for each model. Since this BLS backend + // can handle requests in parallel, we will send all the inference + // requests first and then retrieve them later. + std::vector> futures(2); + + // The inference request object for sending internal requests. + TRITONSERVER_InferenceRequest* irequest = nullptr; + + // For each inference request, the backend sends two requests on the + // 'addsub_python' and 'addsub_tf' models. + try { + for (size_t icount = 0; icount < 2; icount++) { + // Initialize the inference request with required information. + THROW_IF_TRITON_ERROR( + PrepareInferenceRequest(bls_request, &irequest, model_names[icount])); + THROW_IF_TRITON_ERROR(PrepareInferenceInput(bls_request, irequest)); + THROW_IF_TRITON_ERROR(PrepareInferenceOutput(bls_request, irequest)); + + // Execute inference request. + THROW_IF_TRITON_ERROR( + model_executor_.AsyncExecute(irequest, &futures[icount])); + } + } + catch (const BLSBackendException& bls_exception) { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, bls_exception.what()); + LOG_IF_ERROR( + TRITONSERVER_InferenceRequestDelete(irequest), + "Failed to delete inference request."); + RESPOND_AND_SET_NULL_IF_ERROR( + response, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "Failed to send inference requests")); + return; + } + + // If both internal requests are sent successfully, retrieve the output from + // each request and construct the final response. + ConstructFinalResponse(response, std::move(futures)); +} + +void +BLSExecutor::ConstructFinalResponse( + TRITONBACKEND_Response** response, + std::vector> futures) +{ + // Prepare two TRITONSERVER_InferenceResponse* objects for 'addsub_python' and + // 'addsub_tf' repectively. + std::vector completed_responses = {nullptr, + nullptr}; + + const char* output_name; + TRITONSERVER_DataType output_datatype; + const int64_t* output_shape; + uint64_t dims_count; + size_t output_byte_size; + TRITONSERVER_MemoryType output_memory_type; + int64_t output_memory_id; + const void* output_base; + void* userp; + for (size_t icount = 0; icount < 2; icount++) { + // Retrieve the corresponding TRITONSERVER_InferenceResponse object from + // 'futures'. The InferResponseComplete function sets the std::promise + // so that this thread will block until the response is returned. 
+ completed_responses[icount] = futures[icount].get(); + try { + THROW_IF_TRITON_ERROR( + TRITONSERVER_InferenceResponseError(completed_responses[icount])); + } + catch (const BLSBackendException& bls_exception) { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, bls_exception.what()); + + if (completed_responses[icount] != nullptr) { + LOG_IF_ERROR( + TRITONSERVER_InferenceResponseDelete(completed_responses[icount]), + "Failed to delete inference response."); + } + return; + } + // Retrieve outputs from 'completed_responses'. + // Extract OUTPUT0 from the 'addsub_python' and OUTPUT1 from the + // 'addsub_tf' model to form the final inference response object. + // Get all the information about the output tensor. + RESPOND_AND_SET_NULL_IF_ERROR( + response, + TRITONSERVER_InferenceResponseOutput( + completed_responses[icount], icount, &output_name, &output_datatype, + &output_shape, &dims_count, &output_base, &output_byte_size, + &output_memory_type, &output_memory_id, &userp)); + + // Create an output tensor in the final response with + // the information retrieved above. + TRITONBACKEND_Output* output; + RESPOND_AND_SET_NULL_IF_ERROR( + response, TRITONBACKEND_ResponseOutput( + *response, &output, output_name, output_datatype, + output_shape, dims_count)); + + // Get a buffer that holds the tensor data for the output. + // We request a buffer in CPU memory but we have to handle any returned + // type. If we get back a buffer in GPU memory we just fail the request. + void* output_buffer; + output_memory_type = TRITONSERVER_MEMORY_CPU; + RESPOND_AND_SET_NULL_IF_ERROR( + response, TRITONBACKEND_OutputBuffer( + output, &output_buffer, output_byte_size, + &output_memory_type, &output_memory_id)); + if (output_memory_type == TRITONSERVER_MEMORY_GPU) { + RESPOND_AND_SET_NULL_IF_ERROR( + response, TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "failed to create output buffer in CPU memory")); + } + + // Fill the BLS output buffer with output data returned by internal + // requests. + memcpy(output_buffer, output_base, output_byte_size); + + LOG_IF_ERROR( + TRITONSERVER_InferenceResponseDelete(completed_responses[icount]), + "Failed to delete inference response."); + } +} + +}}} // namespace triton::backend::bls diff --git a/3rdparty/backend-r22.12/examples/backends/bls/src/bls.h b/3rdparty/backend-r22.12/examples/backends/bls/src/bls.h new file mode 100644 index 0000000000000000000000000000000000000000..a0a3a1ed0d448402c04d73270034a3aa54120691 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/bls/src/bls.h @@ -0,0 +1,79 @@ +// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include "bls_utils.h" +#include "triton/backend/backend_common.h" +#include "triton/core/tritonbackend.h" +#include "triton/core/tritonserver.h" + +namespace triton { namespace backend { namespace bls { + +// +// BLSExecutor +// +// Includes the custom BLS logic for this backend. +// This class shows how to utilize Triton in-process C-API to build the +// execution pipeline. +// +class BLSExecutor { + public: + BLSExecutor(TRITONSERVER_Server* server); + + // Prepares the inference request that will be used internally. + TRITONSERVER_Error* PrepareInferenceRequest( + TRITONBACKEND_Request* bls_request, + TRITONSERVER_InferenceRequest** irequest, const std::string model_name); + + // Prepares the input for the internal inference request. + TRITONSERVER_Error* PrepareInferenceInput( + TRITONBACKEND_Request* bls_request, + TRITONSERVER_InferenceRequest* irequest); + + // Prepares the output for the internal inference request. + TRITONSERVER_Error* PrepareInferenceOutput( + TRITONBACKEND_Request* bls_request, + TRITONSERVER_InferenceRequest* irequest); + + // Performs the whole BLS pipeline. + void Execute( + TRITONBACKEND_Request* bls_request, TRITONBACKEND_Response** response); + + // Constructs the final response. + void ConstructFinalResponse( + TRITONBACKEND_Response** response, + std::vector> futures); + + private: + // The server object that encapsulates all the functionality of the Triton + // server and allows access to the Triton server API. + TRITONSERVER_Server* server_; + + // The ModelExecutor object for executing inference request on a model. + ModelExecutor model_executor_; +}; + +}}} // namespace triton::backend::bls diff --git a/3rdparty/backend-r22.12/examples/backends/bls/src/bls_utils.cc b/3rdparty/backend-r22.12/examples/backends/bls/src/bls_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..d935275309fdb7dfaf168882e13ec866ac868379 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/bls/src/bls_utils.cc @@ -0,0 +1,186 @@ +// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. 
+// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "bls_utils.h" + +namespace triton { namespace backend { namespace bls { + +TRITONSERVER_Error* +CPUAllocator( + TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type, + int64_t preferred_memory_type_id, void* userp, void** buffer, + void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type, + int64_t* actual_memory_type_id) +{ + // For simplicity, this backend example always uses CPU memory regardless of + // the preferred memory type. You can make the actual memory type and id that + // we allocate be the same as preferred memory type. You can also provide a + // customized allocator to support different preferred_memory_type, and reuse + // memory buffer when possible. + *actual_memory_type = TRITONSERVER_MEMORY_CPU; + *actual_memory_type_id = preferred_memory_type_id; + + // If 'byte_size' is zero just return 'buffer' == nullptr, we don't + // need to do any other book-keeping. + if (byte_size == 0) { + *buffer = nullptr; + *buffer_userp = nullptr; + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, ("allocated " + std::to_string(byte_size) + + " bytes for result tensor " + tensor_name) + .c_str()); + } else { + void* allocated_ptr = nullptr; + *actual_memory_type = TRITONSERVER_MEMORY_CPU; + allocated_ptr = malloc(byte_size); + + // Pass the tensor name with buffer_userp so we can show it when + // releasing the buffer. 
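+    // The tensor name stored in *buffer_userp below is heap-allocated here
+    // and later freed by ResponseRelease when the buffer is released.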
+ if (allocated_ptr != nullptr) { + *buffer = allocated_ptr; + *buffer_userp = new std::string(tensor_name); + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + ("allocated " + std::to_string(byte_size) + " bytes in " + + TRITONSERVER_MemoryTypeString(*actual_memory_type) + + " for result tensor " + tensor_name) + .c_str()); + } + } + + return nullptr; // Success +} + +TRITONSERVER_Error* +ResponseRelease( + TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp, + size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) +{ + std::string* name = nullptr; + if (buffer_userp != nullptr) { + name = reinterpret_cast(buffer_userp); + } else { + name = new std::string(""); + } + + std::stringstream ss; + ss << buffer; + std::string buffer_str = ss.str(); + + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + ("Releasing buffer " + buffer_str + " of size " + + std::to_string(byte_size) + " in " + + TRITONSERVER_MemoryTypeString(memory_type) + " for result '" + *name) + .c_str()); + + switch (memory_type) { + case TRITONSERVER_MEMORY_CPU: + free(buffer); + break; + default: + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + std::string( + "error: unexpected buffer allocated in CUDA managed memory") + .c_str()); + break; + } + + delete name; + + return nullptr; // Success +} + +void +InferRequestComplete( + TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp) +{ + if (request != nullptr) { + LOG_IF_ERROR( + TRITONSERVER_InferenceRequestDelete(request), + "Failed to delete inference request."); + } +} + +void +InferResponseComplete( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp) +{ + // The following logic only works for non-decoupled models as for decoupled + // models it may send multiple responses for a request or not send any + // responses for a request. Need to modify this function if the model is using + // decoupled API. + if (response != nullptr) { + // Send 'response' to the future. + std::promise* p = + reinterpret_cast*>(userp); + p->set_value(response); + delete p; + } +} + +ModelExecutor::ModelExecutor(TRITONSERVER_Server* server) : server_(server) +{ + // When triton needs a buffer to hold an output tensor, it will ask + // us to provide the buffer. In this way we can have any buffer + // management and sharing strategy that we want. To communicate to + // triton the functions that we want it to call to perform the + // allocations, we create a "response allocator" object. We pass + // this response allocate object to triton when requesting + // inference. We can reuse this response allocator object for any + // number of inference requests. + allocator_ = nullptr; + THROW_IF_TRITON_ERROR(TRITONSERVER_ResponseAllocatorNew( + &allocator_, CPUAllocator, ResponseRelease, nullptr /* start_fn */)); +} + +TRITONSERVER_Error* +ModelExecutor::AsyncExecute( + TRITONSERVER_InferenceRequest* irequest, + std::future* future) +{ + // Perform inference by calling TRITONSERVER_ServerInferAsync. This + // call is asychronous and therefore returns immediately. The + // completion of the inference and delivery of the response is done + // by triton by calling the "response complete" callback functions + // (InferResponseComplete in this case). 
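+  // A promise/future pair links this call to that callback: the raw promise
+  // pointer is passed to Triton as the callback 'userp', InferResponseComplete
+  // fulfills (and deletes) it on Triton's callback thread, and the caller
+  // blocks on the returned future until the response arrives.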
+ auto p = new std::promise(); + *future = p->get_future(); + + RETURN_IF_ERROR(TRITONSERVER_InferenceRequestSetResponseCallback( + irequest, allocator_, nullptr /* response_allocator_userp */, + InferResponseComplete, reinterpret_cast(p))); + + RETURN_IF_ERROR( + TRITONSERVER_ServerInferAsync(server_, irequest, nullptr /* trace */)); + + return nullptr; // success +} + +}}} // namespace triton::backend::bls diff --git a/3rdparty/backend-r22.12/examples/backends/bls/src/bls_utils.h b/3rdparty/backend-r22.12/examples/backends/bls/src/bls_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..e5482e0adfbe3e2095419839ba32b1182f323370 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/bls/src/bls_utils.h @@ -0,0 +1,98 @@ +// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include +#include "triton/backend/backend_common.h" +#include "triton/core/tritonbackend.h" +#include "triton/core/tritonserver.h" + +namespace triton { namespace backend { namespace bls { + +#define THROW_IF_TRITON_ERROR(X) \ + do { \ + TRITONSERVER_Error* tie_err__ = (X); \ + if (tie_err__ != nullptr) { \ + throw BLSBackendException(TRITONSERVER_ErrorMessage(tie_err__)); \ + } \ + } while (false) + +// +// BLSBackendException +// +// Exception thrown if error occurs in BLSBackend. +// +struct BLSBackendException : std::exception { + BLSBackendException(const std::string& message) : message_(message) {} + + const char* what() const throw() { return message_.c_str(); } + + std::string message_; +}; + +// Performs the allocations of output tensors. +TRITONSERVER_Error* CPUAllocator( + TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type, + int64_t preferred_memory_type_id, void* userp, void** buffer, + void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type, + int64_t* actual_memory_type_id); + +// Callback functions for server inference. 
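+// ResponseRelease frees the output buffers handed out by CPUAllocator,
+// InferRequestComplete deletes the internal inference request once Triton
+// releases it, and InferResponseComplete forwards the completed response to
+// the waiting std::promise (see bls_utils.cc).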
+TRITONSERVER_Error* ResponseRelease( + TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp, + size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id); +void InferRequestComplete( + TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp); +void InferResponseComplete( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, + void* userp); + +// +// ModelExecutor +// +// Execute inference request on a model. +// +class ModelExecutor { + public: + ModelExecutor(TRITONSERVER_Server* server); + + // Performs async inference request. + TRITONSERVER_Error* AsyncExecute( + TRITONSERVER_InferenceRequest* irequest, + std::future* future); + + private: + // The server object that encapsulates all the functionality of the Triton + // server and allows access to the Triton server API. + TRITONSERVER_Server* server_; + + // The allocator object that will be used for allocating output tensors. + TRITONSERVER_ResponseAllocator* allocator_; +}; + +}}} // namespace triton::backend::bls diff --git a/3rdparty/backend-r22.12/examples/backends/bls/src/libtriton_bls.ldscript b/3rdparty/backend-r22.12/examples/backends/bls/src/libtriton_bls.ldscript new file mode 100644 index 0000000000000000000000000000000000000000..b7c0c7556550578b5ef1cc722cf7357602bdcfc5 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/bls/src/libtriton_bls.ldscript @@ -0,0 +1,30 @@ +# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +{ + global: + TRITONBACKEND_*; + local: *; +}; diff --git a/3rdparty/backend-r22.12/examples/backends/minimal/cmake/TutorialMinimalBackendConfig.cmake.in b/3rdparty/backend-r22.12/examples/backends/minimal/cmake/TutorialMinimalBackendConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..2e408d0306e0b74611c53033eb255f58ef38a528 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/minimal/cmake/TutorialMinimalBackendConfig.cmake.in @@ -0,0 +1,39 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include(CMakeFindDependencyMacro) + +get_filename_component( + TUTORIALMINIMALBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH +) + +list(APPEND CMAKE_MODULE_PATH ${TUTORIALMINIMALBACKEND_CMAKE_DIR}) + +if(NOT TARGET TutorialMinimalBackend::triton-minimal-backend) + include("${TUTORIALMINIMALBACKEND_CMAKE_DIR}/TutorialMinimalBackendTargets.cmake") +endif() + +set(TUTORIALMINIMALBACKEND_LIBRARIES TutorialMinimalBackend::triton-minimal-backend) diff --git a/3rdparty/backend-r22.12/examples/backends/minimal/src/libtriton_minimal.ldscript b/3rdparty/backend-r22.12/examples/backends/minimal/src/libtriton_minimal.ldscript new file mode 100644 index 0000000000000000000000000000000000000000..748714d16fd3a4d028e71216f33da78ff4e6dbe9 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/minimal/src/libtriton_minimal.ldscript @@ -0,0 +1,30 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +{ + global: + TRITONBACKEND_*; + local: *; +}; diff --git a/3rdparty/backend-r22.12/examples/backends/minimal/src/minimal.cc b/3rdparty/backend-r22.12/examples/backends/minimal/src/minimal.cc new file mode 100644 index 0000000000000000000000000000000000000000..6e29e3c78fde11c38c48bd8b2e7eb54985b967cd --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/minimal/src/minimal.cc @@ -0,0 +1,434 @@ +// Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "triton/backend/backend_common.h" +#include "triton/backend/backend_input_collector.h" +#include "triton/backend/backend_model.h" +#include "triton/backend/backend_model_instance.h" +#include "triton/backend/backend_output_responder.h" +#include "triton/core/tritonbackend.h" + +namespace triton { namespace backend { namespace minimal { + +// +// Minimal backend that demonstrates the TRITONBACKEND API. This +// backend works for any model that has 1 input called "IN0" with +// INT32 datatype and shape [ 4 ] and 1 output called "OUT0" with +// INT32 datatype and shape [ 4 ]. The backend supports both batching +// and non-batching models. +// +// For each batch of requests, the backend returns the input tensor +// value in the output tensor. +// + +///////////// + +// +// ModelState +// +// State associated with a model that is using this backend. An object +// of this class is created and associated with each +// TRITONBACKEND_Model. 
ModelState is derived from BackendModel class +// provided in the backend utilities that provides many common +// functions. +// +class ModelState : public BackendModel { + public: + static TRITONSERVER_Error* Create( + TRITONBACKEND_Model* triton_model, ModelState** state); + virtual ~ModelState() = default; + + private: + ModelState(TRITONBACKEND_Model* triton_model) : BackendModel(triton_model) {} +}; + +TRITONSERVER_Error* +ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) +{ + try { + *state = new ModelState(triton_model); + } + catch (const BackendModelException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} + +extern "C" { + +// Triton calls TRITONBACKEND_ModelInitialize when a model is loaded +// to allow the backend to create any state associated with the model, +// and to also examine the model configuration to determine if the +// configuration is suitable for the backend. Any errors reported by +// this function will prevent the model from loading. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) +{ + // Create a ModelState object and associate it with the + // TRITONBACKEND_Model. If anything goes wrong with initialization + // of the model state then an error is returned and Triton will fail + // to load the model. + ModelState* model_state; + RETURN_IF_ERROR(ModelState::Create(model, &model_state)); + RETURN_IF_ERROR( + TRITONBACKEND_ModelSetState(model, reinterpret_cast(model_state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_ModelFinalize when a model is no longer +// needed. The backend should cleanup any state associated with the +// model. This function will not be called until all model instances +// of the model have been finalized. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) +{ + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate)); + ModelState* model_state = reinterpret_cast(vstate); + delete model_state; + + return nullptr; // success +} + +} // extern "C" + +///////////// + +// +// ModelInstanceState +// +// State associated with a model instance. An object of this class is +// created and associated with each +// TRITONBACKEND_ModelInstance. ModelInstanceState is derived from +// BackendModelInstance class provided in the backend utilities that +// provides many common functions. +// +class ModelInstanceState : public BackendModelInstance { + public: + static TRITONSERVER_Error* Create( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state); + virtual ~ModelInstanceState() = default; + + // Get the state of the model that corresponds to this instance. 
+  ModelState* StateForModel() const { return model_state_; }
+
+ private:
+  ModelInstanceState(
+      ModelState* model_state,
+      TRITONBACKEND_ModelInstance* triton_model_instance)
+      : BackendModelInstance(model_state, triton_model_instance),
+        model_state_(model_state)
+  {
+  }
+
+  ModelState* model_state_;
+};
+
+TRITONSERVER_Error*
+ModelInstanceState::Create(
+    ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance,
+    ModelInstanceState** state)
+{
+  try {
+    *state = new ModelInstanceState(model_state, triton_model_instance);
+  }
+  catch (const BackendModelInstanceException& ex) {
+    RETURN_ERROR_IF_TRUE(
+        ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
+        std::string("unexpected nullptr in BackendModelInstanceException"));
+    RETURN_IF_ERROR(ex.err_);
+  }
+
+  return nullptr;  // success
+}
+
+extern "C" {
+
+// Triton calls TRITONBACKEND_ModelInstanceInitialize when a model
+// instance is created to allow the backend to initialize any state
+// associated with the instance.
+//
+TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance)
+{
+  // Get the model state associated with this instance's model.
+  TRITONBACKEND_Model* model;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model));
+
+  void* vmodelstate;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate));
+  ModelState* model_state = reinterpret_cast<ModelState*>(vmodelstate);
+
+  // Create a ModelInstanceState object and associate it with the
+  // TRITONBACKEND_ModelInstance.
+  ModelInstanceState* instance_state;
+  RETURN_IF_ERROR(
+      ModelInstanceState::Create(model_state, instance, &instance_state));
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState(
+      instance, reinterpret_cast<void*>(instance_state)));
+
+  return nullptr;  // success
+}
+
+// Triton calls TRITONBACKEND_ModelInstanceFinalize when a model
+// instance is no longer needed. The backend should cleanup any state
+// associated with the model instance.
+//
+TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance)
+{
+  void* vstate;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate));
+  ModelInstanceState* instance_state =
+      reinterpret_cast<ModelInstanceState*>(vstate);
+  delete instance_state;
+
+  return nullptr;  // success
+}
+
+}  // extern "C"
+
+/////////////
+
+extern "C" {
+
+// When Triton calls TRITONBACKEND_ModelInstanceExecute it is required
+// that a backend create a response for each request in the batch. A
+// response may be the output tensors required for that request or may
+// be an error that is returned in the response.
+//
+TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceExecute(
+    TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests,
+    const uint32_t request_count)
+{
+  // Triton will not call this function simultaneously for the same
+  // 'instance'. But since this backend could be used by multiple
+  // instances from multiple models the implementation needs to handle
+  // multiple calls to this function at the same time (with different
+  // 'instance' objects). Best practice for a high-performance
+  // implementation is to avoid introducing mutex/lock and instead use
+  // only function-local and model-instance-specific state.
+  ModelInstanceState* instance_state;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(
+      instance, reinterpret_cast<void**>(&instance_state)));
+  ModelState* model_state = instance_state->StateForModel();
+
+  // 'responses' is initialized as a parallel array to 'requests',
+  // with one TRITONBACKEND_Response object for each
+  // TRITONBACKEND_Request object. If something goes wrong while
+  // creating these response objects, the backend simply returns an
+  // error from TRITONBACKEND_ModelInstanceExecute, indicating to
+  // Triton that this backend did not create or send any responses and
+  // so it is up to Triton to create and send an appropriate error
+  // response for each request. RETURN_IF_ERROR is one of several
+  // useful macros for error handling that can be found in
+  // backend_common.h.
+
+  std::vector<TRITONBACKEND_Response*> responses;
+  responses.reserve(request_count);
+  for (uint32_t r = 0; r < request_count; ++r) {
+    TRITONBACKEND_Request* request = requests[r];
+    TRITONBACKEND_Response* response;
+    RETURN_IF_ERROR(TRITONBACKEND_ResponseNew(&response, request));
+    responses.push_back(response);
+  }
+
+  // At this point, the backend takes ownership of 'requests', which
+  // means that it is responsible for sending a response for every
+  // request. From here, even if something goes wrong in processing,
+  // the backend must return 'nullptr' from this function to indicate
+  // success. Any errors and failures must be communicated via the
+  // response objects.
+  //
+  // To simplify error handling, the backend utilities manage
+  // 'responses' in a specific way and it is recommended that backends
+  // follow this same pattern. When an error is detected in the
+  // processing of a request, an appropriate error response is sent
+  // and the corresponding TRITONBACKEND_Response object within
+  // 'responses' is set to nullptr to indicate that the
+  // request/response has already been handled and no further
+  // processing should be performed for that request. Even if all
+  // responses fail, the backend still allows execution to flow to the
+  // end of the function. RESPOND_AND_SET_NULL_IF_ERROR, and
+  // RESPOND_ALL_AND_SET_NULL_IF_ERROR are macros from
+  // backend_common.h that assist in this management of response
+  // objects.
+
+  // The backend could iterate over the 'requests' and process each
+  // one separately. But for performance reasons it is usually
+  // preferred to create batched input tensors that are processed
+  // simultaneously. This is especially true for devices like GPUs
+  // that are capable of exploiting the large amount of parallelism
+  // exposed by larger data sets.
+  //
+  // The backend utilities provide a "collector" to facilitate this
+  // batching process. The 'collector's ProcessTensor function will
+  // combine a tensor's value from each request in the batch into a
+  // single contiguous buffer. The buffer can be provided by the
+  // backend or the 'collector' can create and manage it. In this
+  // backend, there is not a specific buffer into which the batch
+  // should be created, so use ProcessTensor arguments that cause the
+  // collector to manage it.
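+  //
+  // For illustration only (not part of the upstream example): with a
+  // batch of two requests whose IN0 values are [ 10, 11, 12, 13 ] and
+  // [ 20, 21, 22, 23 ], the collector gathers a single contiguous CPU
+  // buffer holding the eight INT32 values 10,11,12,13,20,21,22,23, so
+  // 'input_buffer_byte_size' would be 8 * sizeof(int32_t) = 32 bytes.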
+
+  BackendInputCollector collector(
+      requests, request_count, &responses, model_state->TritonMemoryManager(),
+      false /* pinned_enabled */, nullptr /* stream*/);
+
+  // To instruct ProcessTensor to "gather" the entire batch of IN0
+  // input tensors into a single contiguous buffer in CPU memory, set
+  // the "allowed input types" to be the CPU ones (see tritonserver.h
+  // in the triton-inference-server/core repo for allowed memory
+  // types).
+  std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> allowed_input_types =
+      {{TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
+
+  const char* input_buffer;
+  size_t input_buffer_byte_size;
+  TRITONSERVER_MemoryType input_buffer_memory_type;
+  int64_t input_buffer_memory_type_id;
+
+  RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+      responses, request_count,
+      collector.ProcessTensor(
+          "IN0", nullptr /* existing_buffer */,
+          0 /* existing_buffer_byte_size */, allowed_input_types, &input_buffer,
+          &input_buffer_byte_size, &input_buffer_memory_type,
+          &input_buffer_memory_type_id));
+
+  // Finalize the collector. If 'true' is returned, 'input_buffer'
+  // will not be valid until the backend synchronizes the CUDA
+  // stream or event that was used when creating the collector. For
+  // this backend, GPU is not supported and so no CUDA sync should
+  // be needed; so if 'true' is returned simply log an error.
+  const bool need_cuda_input_sync = collector.Finalize();
+  if (need_cuda_input_sync) {
+    LOG_MESSAGE(
+        TRITONSERVER_LOG_ERROR,
+        "'minimal' backend: unexpected CUDA sync required by collector");
+  }
+
+  // 'input_buffer' contains the batched "IN0" tensor. The backend can
+  // implement whatever logic is necessary to produce "OUT0". This
+  // backend simply returns the IN0 value in OUT0 so no actual
+  // computation is needed.
+
+  LOG_MESSAGE(
+      TRITONSERVER_LOG_INFO,
+      (std::string("model ") + model_state->Name() + ": requests in batch " +
+       std::to_string(request_count))
+          .c_str());
+  std::string tstr;
+  IGNORE_ERROR(BufferAsTypedString(
+      tstr, input_buffer, input_buffer_byte_size, TRITONSERVER_TYPE_INT32));
+  LOG_MESSAGE(
+      TRITONSERVER_LOG_INFO,
+      (std::string("batched IN0 value: ") + tstr).c_str());
+
+  const char* output_buffer = input_buffer;
+  TRITONSERVER_MemoryType output_buffer_memory_type = input_buffer_memory_type;
+  int64_t output_buffer_memory_type_id = input_buffer_memory_type_id;
+
+  // This backend supports models that batch along the first dimension
+  // and those that don't batch. For non-batch models the output shape
+  // will be [ 4 ]. For batch models the output shape will be [ -1, 4 ]
+  // and the backend "responder" utility below will set the
+  // appropriate batch dimension value for each response.
+  std::vector<int64_t> output_batch_shape;
+  bool supports_first_dim_batching;
+  RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+      responses, request_count,
+      model_state->SupportsFirstDimBatching(&supports_first_dim_batching));
+  if (supports_first_dim_batching) {
+    output_batch_shape.push_back(-1);
+  }
+  output_batch_shape.push_back(4);
+
+  // Because the OUT0 values are concatenated into a single contiguous
+  // 'output_buffer', the backend must "scatter" them out to the
+  // individual response OUT0 tensors. The backend utilities provide
+  // a "responder" to facilitate this scattering process.
+
+  // The 'responder's ProcessTensor function will copy the portion of
+  // 'output_buffer' corresponding to each request's output into the
+  // response for that request.
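+  //
+  // Continuing the illustration above (not part of the upstream
+  // example): for the two batched requests, bytes 0..15 of
+  // 'output_buffer' would be copied into the first request's OUT0
+  // tensor and bytes 16..31 into the second request's OUT0 tensor,
+  // each response getting shape [ 1, 4 ].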
+ + BackendOutputResponder responder( + requests, request_count, &responses, model_state->TritonMemoryManager(), + supports_first_dim_batching, false /* pinned_enabled */, + nullptr /* stream*/); + + responder.ProcessTensor( + "OUT0", TRITONSERVER_TYPE_INT32, output_batch_shape, output_buffer, + output_buffer_memory_type, output_buffer_memory_type_id); + + // Finalize the responder. If 'true' is returned, the OUT0 + // tensors' data will not be valid until the backend synchronizes + // the CUDA stream or event that was used when creating the + // responder. For this backend, GPU is not supported and so no + // CUDA sync should be needed; so if 'true' is returned simply log + // an error. + const bool need_cuda_output_sync = responder.Finalize(); + if (need_cuda_output_sync) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + "'minimal' backend: unexpected CUDA sync required by responder"); + } + + // Send all the responses that haven't already been sent because of + // an earlier error. + for (auto& response : responses) { + if (response != nullptr) { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr), + "failed to send response"); + } + } + + // Done with the request objects so release them. + for (uint32_t r = 0; r < request_count; ++r) { + auto& request = requests[r]; + LOG_IF_ERROR( + TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), + "failed releasing request"); + } + + return nullptr; // success +} + +} // extern "C" + +}}} // namespace triton::backend::minimal diff --git a/3rdparty/backend-r22.12/examples/backends/recommended/cmake/TutorialRecommendedBackendConfig.cmake.in b/3rdparty/backend-r22.12/examples/backends/recommended/cmake/TutorialRecommendedBackendConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..4007f9f8d7a4f302be52acc868532b4929739f48 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/recommended/cmake/TutorialRecommendedBackendConfig.cmake.in @@ -0,0 +1,39 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include(CMakeFindDependencyMacro) + +get_filename_component( + TUTORIALRECOMMENDEDBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH +) + +list(APPEND CMAKE_MODULE_PATH ${TUTORIALRECOMMENDEDBACKEND_CMAKE_DIR}) + +if(NOT TARGET TutorialRecommendedBackend::triton-recommended-backend) + include("${TUTORIALRECOMMENDEDBACKEND_CMAKE_DIR}/TutorialRecommendedBackendTargets.cmake") +endif() + +set(TUTORIALRECOMMENDEDBACKEND_LIBRARIES TutorialRecommendedBackend::triton-recommended-backend) diff --git a/3rdparty/backend-r22.12/examples/backends/recommended/src/libtriton_recommended.ldscript b/3rdparty/backend-r22.12/examples/backends/recommended/src/libtriton_recommended.ldscript new file mode 100644 index 0000000000000000000000000000000000000000..748714d16fd3a4d028e71216f33da78ff4e6dbe9 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/recommended/src/libtriton_recommended.ldscript @@ -0,0 +1,30 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +{ + global: + TRITONBACKEND_*; + local: *; +}; diff --git a/3rdparty/backend-r22.12/examples/backends/recommended/src/recommended.cc b/3rdparty/backend-r22.12/examples/backends/recommended/src/recommended.cc new file mode 100644 index 0000000000000000000000000000000000000000..02f46724a96c503c49be2a31288fef342d61bb35 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/backends/recommended/src/recommended.cc @@ -0,0 +1,750 @@ +// Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "triton/backend/backend_common.h" +#include "triton/backend/backend_input_collector.h" +#include "triton/backend/backend_model.h" +#include "triton/backend/backend_model_instance.h" +#include "triton/backend/backend_output_responder.h" +#include "triton/core/tritonbackend.h" + +namespace triton { namespace backend { namespace recommended { + +// +// Backend that demonstrates the TRITONBACKEND API. This backend works +// for any model that has 1 input with any datatype and any shape and +// 1 output with the same shape and datatype as the input. The backend +// supports both batching and non-batching models. +// +// For each batch of requests, the backend returns the input tensor +// value in the output tensor. +// + +///////////// + +extern "C" { + +// Triton calls TRITONBACKEND_Initialize when a backend is loaded into +// Triton to allow the backend to create and initialize any state that +// is intended to be shared across all models and model instances that +// use the backend. The backend should also verify version +// compatibility with Triton in this function. +// +TRITONSERVER_Error* +TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) +{ + const char* cname; + RETURN_IF_ERROR(TRITONBACKEND_BackendName(backend, &cname)); + std::string name(cname); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("TRITONBACKEND_Initialize: ") + name).c_str()); + + // Check the backend API version that Triton supports vs. what this + // backend was compiled against. Make sure that the Triton major + // version is the same and the minor version is >= what this backend + // uses. + uint32_t api_version_major, api_version_minor; + RETURN_IF_ERROR( + TRITONBACKEND_ApiVersion(&api_version_major, &api_version_minor)); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Triton TRITONBACKEND API version: ") + + std::to_string(api_version_major) + "." 
+ + std::to_string(api_version_minor)) + .c_str()); + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("'") + name + "' TRITONBACKEND API version: " + + std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." + + std::to_string(TRITONBACKEND_API_VERSION_MINOR)) + .c_str()); + + if ((api_version_major != TRITONBACKEND_API_VERSION_MAJOR) || + (api_version_minor < TRITONBACKEND_API_VERSION_MINOR)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "triton backend API version does not support this backend"); + } + + // The backend configuration may contain information needed by the + // backend, such as tritonserver command-line arguments. This + // backend doesn't use any such configuration but for this example + // print whatever is available. + TRITONSERVER_Message* backend_config_message; + RETURN_IF_ERROR( + TRITONBACKEND_BackendConfig(backend, &backend_config_message)); + + const char* buffer; + size_t byte_size; + RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson( + backend_config_message, &buffer, &byte_size)); + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("backend configuration:\n") + buffer).c_str()); + + // This backend does not require any "global" state but as an + // example create a string to demonstrate. + std::string* state = new std::string("backend state"); + RETURN_IF_ERROR( + TRITONBACKEND_BackendSetState(backend, reinterpret_cast(state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_Finalize when a backend is no longer +// needed. +// +TRITONSERVER_Error* +TRITONBACKEND_Finalize(TRITONBACKEND_Backend* backend) +{ + // Delete the "global" state associated with the backend. + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &vstate)); + std::string* state = reinterpret_cast(vstate); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("TRITONBACKEND_Finalize: state is '") + *state + "'") + .c_str()); + + delete state; + + return nullptr; // success +} + +} // extern "C" + +///////////// + +// +// ModelState +// +// State associated with a model that is using this backend. An object +// of this class is created and associated with each +// TRITONBACKEND_Model. ModelState is derived from BackendModel class +// provided in the backend utilities that provides many common +// functions. +// +class ModelState : public BackendModel { + public: + static TRITONSERVER_Error* Create( + TRITONBACKEND_Model* triton_model, ModelState** state); + virtual ~ModelState() = default; + + // Name of the input and output tensor + const std::string& InputTensorName() const { return input_name_; } + const std::string& OutputTensorName() const { return output_name_; } + + // Datatype of the input and output tensor + TRITONSERVER_DataType TensorDataType() const { return datatype_; } + + // Shape of the input and output tensor as given in the model + // configuration file. This shape will not include the batch + // dimension (if the model has one). + const std::vector& TensorNonBatchShape() const { return nb_shape_; } + + // Shape of the input and output tensor, including the batch + // dimension (if the model has one). This method cannot be called + // until the model is completely loaded and initialized, including + // all instances of the model. In practice, this means that backend + // should only call it in TRITONBACKEND_ModelInstanceExecute. + TRITONSERVER_Error* TensorShape(std::vector& shape); + + // Validate that this model is supported by this backend. 
+ TRITONSERVER_Error* ValidateModelConfig(); + + private: + ModelState(TRITONBACKEND_Model* triton_model); + + std::string input_name_; + std::string output_name_; + + TRITONSERVER_DataType datatype_; + + bool shape_initialized_; + std::vector nb_shape_; + std::vector shape_; +}; + +ModelState::ModelState(TRITONBACKEND_Model* triton_model) + : BackendModel(triton_model), shape_initialized_(false) +{ + // Validate that the model's configuration matches what is supported + // by this backend. + THROW_IF_BACKEND_MODEL_ERROR(ValidateModelConfig()); +} + +TRITONSERVER_Error* +ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) +{ + try { + *state = new ModelState(triton_model); + } + catch (const BackendModelException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +ModelState::TensorShape(std::vector& shape) +{ + // This backend supports models that batch along the first dimension + // and those that don't batch. For non-batch models the output shape + // will be the shape from the model configuration. For batch models + // the output shape will be the shape from the model configuration + // prepended with [ -1 ] to represent the batch dimension. The + // backend "responder" utility used below will set the appropriate + // batch dimension value for each response. The shape needs to be + // initialized lazily because the SupportsFirstDimBatching function + // cannot be used until the model is completely loaded. + if (!shape_initialized_) { + bool supports_first_dim_batching; + RETURN_IF_ERROR(SupportsFirstDimBatching(&supports_first_dim_batching)); + if (supports_first_dim_batching) { + shape_.push_back(-1); + } + + shape_.insert(shape_.end(), nb_shape_.begin(), nb_shape_.end()); + shape_initialized_ = true; + } + + shape = shape_; + + return nullptr; // success +} + +TRITONSERVER_Error* +ModelState::ValidateModelConfig() +{ + // If verbose logging is enabled, dump the model's configuration as + // JSON into the console output. + if (TRITONSERVER_LogIsEnabled(TRITONSERVER_LOG_VERBOSE)) { + common::TritonJson::WriteBuffer buffer; + RETURN_IF_ERROR(ModelConfig().PrettyWrite(&buffer)); + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("model configuration:\n") + buffer.Contents()).c_str()); + } + + // ModelConfig is the model configuration as a TritonJson + // object. Use the TritonJson utilities to parse the JSON and + // determine if the configuration is supported by this backend. + common::TritonJson::Value inputs, outputs; + RETURN_IF_ERROR(ModelConfig().MemberAsArray("input", &inputs)); + RETURN_IF_ERROR(ModelConfig().MemberAsArray("output", &outputs)); + + // The model must have exactly 1 input and 1 output. + RETURN_ERROR_IF_FALSE( + inputs.ArraySize() == 1, TRITONSERVER_ERROR_INVALID_ARG, + std::string("model configuration must have 1 input")); + RETURN_ERROR_IF_FALSE( + outputs.ArraySize() == 1, TRITONSERVER_ERROR_INVALID_ARG, + std::string("model configuration must have 1 output")); + + common::TritonJson::Value input, output; + RETURN_IF_ERROR(inputs.IndexAsObject(0, &input)); + RETURN_IF_ERROR(outputs.IndexAsObject(0, &output)); + + // Record the input and output name in the model state. 
+ const char* input_name; + size_t input_name_len; + RETURN_IF_ERROR(input.MemberAsString("name", &input_name, &input_name_len)); + input_name_ = std::string(input_name); + + const char* output_name; + size_t output_name_len; + RETURN_IF_ERROR( + output.MemberAsString("name", &output_name, &output_name_len)); + output_name_ = std::string(output_name); + + // Input and output must have same datatype + std::string input_dtype, output_dtype; + RETURN_IF_ERROR(input.MemberAsString("data_type", &input_dtype)); + RETURN_IF_ERROR(output.MemberAsString("data_type", &output_dtype)); + RETURN_ERROR_IF_FALSE( + input_dtype == output_dtype, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected input and output datatype to match, got ") + + input_dtype + " and " + output_dtype); + datatype_ = ModelConfigDataTypeToTritonServerDataType(input_dtype); + + // Input and output must have same shape. Reshape is not supported + // on either input or output so flag an error is the model + // configuration uses it. + triton::common::TritonJson::Value reshape; + RETURN_ERROR_IF_TRUE( + input.Find("reshape", &reshape), TRITONSERVER_ERROR_UNSUPPORTED, + std::string("reshape not supported for input tensor")); + RETURN_ERROR_IF_TRUE( + output.Find("reshape", &reshape), TRITONSERVER_ERROR_UNSUPPORTED, + std::string("reshape not supported for output tensor")); + + std::vector input_shape, output_shape; + RETURN_IF_ERROR(backend::ParseShape(input, "dims", &input_shape)); + RETURN_IF_ERROR(backend::ParseShape(output, "dims", &output_shape)); + + RETURN_ERROR_IF_FALSE( + input_shape == output_shape, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected input and output shape to match, got ") + + backend::ShapeToString(input_shape) + " and " + + backend::ShapeToString(output_shape)); + + nb_shape_ = input_shape; + + return nullptr; // success +} + +extern "C" { + +// Triton calls TRITONBACKEND_ModelInitialize when a model is loaded +// to allow the backend to create any state associated with the model, +// and to also examine the model configuration to determine if the +// configuration is suitable for the backend. Any errors reported by +// this function will prevent the model from loading. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) +{ + // Create a ModelState object and associate it with the + // TRITONBACKEND_Model. If anything goes wrong with initialization + // of the model state then an error is returned and Triton will fail + // to load the model. + ModelState* model_state; + RETURN_IF_ERROR(ModelState::Create(model, &model_state)); + RETURN_IF_ERROR( + TRITONBACKEND_ModelSetState(model, reinterpret_cast(model_state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_ModelFinalize when a model is no longer +// needed. The backend should cleanup any state associated with the +// model. This function will not be called until all model instances +// of the model have been finalized. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) +{ + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate)); + ModelState* model_state = reinterpret_cast(vstate); + delete model_state; + + return nullptr; // success +} + +} // extern "C" + +///////////// + +// +// ModelInstanceState +// +// State associated with a model instance. An object of this class is +// created and associated with each +// TRITONBACKEND_ModelInstance. 
ModelInstanceState is derived from +// BackendModelInstance class provided in the backend utilities that +// provides many common functions. +// +class ModelInstanceState : public BackendModelInstance { + public: + static TRITONSERVER_Error* Create( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state); + virtual ~ModelInstanceState() = default; + + // Get the state of the model that corresponds to this instance. + ModelState* StateForModel() const { return model_state_; } + + private: + ModelInstanceState( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance) + : BackendModelInstance(model_state, triton_model_instance), + model_state_(model_state) + { + } + + ModelState* model_state_; +}; + +TRITONSERVER_Error* +ModelInstanceState::Create( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state) +{ + try { + *state = new ModelInstanceState(model_state, triton_model_instance); + } + catch (const BackendModelInstanceException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelInstanceException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} + +extern "C" { + +// Triton calls TRITONBACKEND_ModelInstanceInitialize when a model +// instance is created to allow the backend to initialize any state +// associated with the instance. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance) +{ + // Get the model state associated with this instance's model. + TRITONBACKEND_Model* model; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); + + void* vmodelstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate)); + ModelState* model_state = reinterpret_cast(vmodelstate); + + // Create a ModelInstanceState object and associate it with the + // TRITONBACKEND_ModelInstance. + ModelInstanceState* instance_state; + RETURN_IF_ERROR( + ModelInstanceState::Create(model_state, instance, &instance_state)); + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState( + instance, reinterpret_cast(instance_state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_ModelInstanceFinalize when a model +// instance is no longer needed. The backend should cleanup any state +// associated with the model instance. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance) +{ + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate)); + ModelInstanceState* instance_state = + reinterpret_cast(vstate); + delete instance_state; + + return nullptr; // success +} + +} // extern "C" + +///////////// + +extern "C" { + +// When Triton calls TRITONBACKEND_ModelInstanceExecute it is required +// that a backend create a response for each request in the batch. A +// response may be the output tensors required for that request or may +// be an error that is returned in the response. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceExecute( + TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, + const uint32_t request_count) +{ + // Collect various timestamps during the execution of this batch or + // requests. These values are reported below before returning from + // the function. 
+
+  uint64_t exec_start_ns = 0;
+  SET_TIMESTAMP(exec_start_ns);
+
+  // Triton will not call this function simultaneously for the same
+  // 'instance'. But since this backend could be used by multiple
+  // instances from multiple models the implementation needs to handle
+  // multiple calls to this function at the same time (with different
+  // 'instance' objects). Best practice for a high-performance
+  // implementation is to avoid introducing mutex/lock and instead use
+  // only function-local and model-instance-specific state.
+  ModelInstanceState* instance_state;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(
+      instance, reinterpret_cast<void**>(&instance_state)));
+  ModelState* model_state = instance_state->StateForModel();
+
+  // 'responses' is initialized as a parallel array to 'requests',
+  // with one TRITONBACKEND_Response object for each
+  // TRITONBACKEND_Request object. If something goes wrong while
+  // creating these response objects, the backend simply returns an
+  // error from TRITONBACKEND_ModelInstanceExecute, indicating to
+  // Triton that this backend did not create or send any responses and
+  // so it is up to Triton to create and send an appropriate error
+  // response for each request. RETURN_IF_ERROR is one of several
+  // useful macros for error handling that can be found in
+  // backend_common.h.
+
+  std::vector<TRITONBACKEND_Response*> responses;
+  responses.reserve(request_count);
+  for (uint32_t r = 0; r < request_count; ++r) {
+    TRITONBACKEND_Request* request = requests[r];
+    TRITONBACKEND_Response* response;
+    RETURN_IF_ERROR(TRITONBACKEND_ResponseNew(&response, request));
+    responses.push_back(response);
+  }
+
+  // At this point, the backend takes ownership of 'requests', which
+  // means that it is responsible for sending a response for every
+  // request. From here, even if something goes wrong in processing,
+  // the backend must return 'nullptr' from this function to indicate
+  // success. Any errors and failures must be communicated via the
+  // response objects.
+  //
+  // To simplify error handling, the backend utilities manage
+  // 'responses' in a specific way and it is recommended that backends
+  // follow this same pattern. When an error is detected in the
+  // processing of a request, an appropriate error response is sent
+  // and the corresponding TRITONBACKEND_Response object within
+  // 'responses' is set to nullptr to indicate that the
+  // request/response has already been handled and no further
+  // processing should be performed for that request. Even if all
+  // responses fail, the backend still allows execution to flow to the
+  // end of the function so that statistics are correctly reported by
+  // the calls to TRITONBACKEND_ModelInstanceReportStatistics and
+  // TRITONBACKEND_ModelInstanceReportBatchStatistics.
+  // RESPOND_AND_SET_NULL_IF_ERROR, and
+  // RESPOND_ALL_AND_SET_NULL_IF_ERROR are macros from
+  // backend_common.h that assist in this management of response
+  // objects.
+
+  // The backend could iterate over the 'requests' and process each
+  // one separately. But for performance reasons it is usually
+  // preferred to create batched input tensors that are processed
+  // simultaneously. This is especially true for devices like GPUs
+  // that are capable of exploiting the large amount of parallelism
+  // exposed by larger data sets.
+  //
+  // The backend utilities provide a "collector" to facilitate this
+  // batching process.
+  // The 'collector's ProcessTensor function will combine a tensor's
+  // value from each request in the batch into a single contiguous
+  // buffer. The buffer can be provided by the backend or the
+  // 'collector' can create and manage it. In this backend, there is
+  // not a specific buffer into which the batch should be created, so
+  // use ProcessTensor arguments that cause the collector to manage
+  // it. ProcessTensor does NOT support TRITONSERVER_TYPE_BYTES
+  // data type.
+
+  BackendInputCollector collector(
+      requests, request_count, &responses, model_state->TritonMemoryManager(),
+      false /* pinned_enabled */, nullptr /* stream*/);
+
+  // To instruct ProcessTensor to "gather" the entire batch of input
+  // tensors into a single contiguous buffer in CPU memory, set the
+  // "allowed input types" to be the CPU ones (see tritonserver.h in
+  // the triton-inference-server/core repo for allowed memory types).
+  std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> allowed_input_types =
+      {{TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
+
+  const char* input_buffer;
+  size_t input_buffer_byte_size;
+  TRITONSERVER_MemoryType input_buffer_memory_type;
+  int64_t input_buffer_memory_type_id;
+
+  RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+      responses, request_count,
+      collector.ProcessTensor(
+          model_state->InputTensorName().c_str(), nullptr /* existing_buffer */,
+          0 /* existing_buffer_byte_size */, allowed_input_types, &input_buffer,
+          &input_buffer_byte_size, &input_buffer_memory_type,
+          &input_buffer_memory_type_id));
+
+  // Finalize the collector. If 'true' is returned, 'input_buffer'
+  // will not be valid until the backend synchronizes the CUDA
+  // stream or event that was used when creating the collector. For
+  // this backend, GPU is not supported and so no CUDA sync should
+  // be needed; so if 'true' is returned simply log an error.
+  const bool need_cuda_input_sync = collector.Finalize();
+  if (need_cuda_input_sync) {
+    LOG_MESSAGE(
+        TRITONSERVER_LOG_ERROR,
+        "'recommended' backend: unexpected CUDA sync required by collector");
+  }
+
+  // 'input_buffer' contains the batched input tensor. The backend can
+  // implement whatever logic is necessary to produce the output
+  // tensor. This backend simply logs the input tensor value and then
+  // returns the input tensor value in the output tensor so no actual
+  // computation is needed.
+
+  uint64_t compute_start_ns = 0;
+  SET_TIMESTAMP(compute_start_ns);
+
+  LOG_MESSAGE(
+      TRITONSERVER_LOG_INFO,
+      (std::string("model ") + model_state->Name() + ": requests in batch " +
+       std::to_string(request_count))
+          .c_str());
+  std::string tstr;
+  IGNORE_ERROR(BufferAsTypedString(
+      tstr, input_buffer, input_buffer_byte_size,
+      model_state->TensorDataType()));
+  LOG_MESSAGE(
+      TRITONSERVER_LOG_INFO,
+      (std::string("batched " + model_state->InputTensorName() + " value: ") +
+       tstr)
+          .c_str());
+
+  const char* output_buffer = input_buffer;
+  TRITONSERVER_MemoryType output_buffer_memory_type = input_buffer_memory_type;
+  int64_t output_buffer_memory_type_id = input_buffer_memory_type_id;
+
+  uint64_t compute_end_ns = 0;
+  SET_TIMESTAMP(compute_end_ns);
+
+  bool supports_first_dim_batching;
+  RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+      responses, request_count,
+      model_state->SupportsFirstDimBatching(&supports_first_dim_batching));
+
+  std::vector<int64_t> tensor_shape;
+  RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+      responses, request_count, model_state->TensorShape(tensor_shape));
+
+  // Because the output tensor values are concatenated into a single
+  // contiguous 'output_buffer', the backend must "scatter" them out
+  // to the individual response output tensors. The backend utilities
+  // provide a "responder" to facilitate this scattering process.
+  // BackendOutputResponder does NOT support TRITONSERVER_TYPE_BYTES
+  // data type.
+
+  // The 'responder's ProcessTensor function will copy the portion of
+  // 'output_buffer' corresponding to each request's output into the
+  // response for that request.
+
+  BackendOutputResponder responder(
+      requests, request_count, &responses, model_state->TritonMemoryManager(),
+      supports_first_dim_batching, false /* pinned_enabled */,
+      nullptr /* stream*/);
+
+  responder.ProcessTensor(
+      model_state->OutputTensorName().c_str(), model_state->TensorDataType(),
+      tensor_shape, output_buffer, output_buffer_memory_type,
+      output_buffer_memory_type_id);
+
+  // Finalize the responder. If 'true' is returned, the output
+  // tensors' data will not be valid until the backend synchronizes
+  // the CUDA stream or event that was used when creating the
+  // responder. For this backend, GPU is not supported and so no CUDA
+  // sync should be needed; so if 'true' is returned simply log an
+  // error.
+  const bool need_cuda_output_sync = responder.Finalize();
+  if (need_cuda_output_sync) {
+    LOG_MESSAGE(
+        TRITONSERVER_LOG_ERROR,
+        "'recommended' backend: unexpected CUDA sync required by responder");
+  }
+
+  // Send all the responses that haven't already been sent because of
+  // an earlier error.
+  for (auto& response : responses) {
+    if (response != nullptr) {
+      LOG_IF_ERROR(
+          TRITONBACKEND_ResponseSend(
+              response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr),
+          "failed to send response");
+    }
+  }
+
+  uint64_t exec_end_ns = 0;
+  SET_TIMESTAMP(exec_end_ns);
+
+#ifdef TRITON_ENABLE_STATS
+  // For batch statistics need to know the total batch size of the
+  // requests. This is not necessarily just the number of requests,
+  // because if the model supports batching then any request can be a
+  // batched request itself.
+ size_t total_batch_size = 0; + if (!supports_first_dim_batching) { + total_batch_size = request_count; + } else { + for (uint32_t r = 0; r < request_count; ++r) { + auto& request = requests[r]; + TRITONBACKEND_Input* input = nullptr; + LOG_IF_ERROR( + TRITONBACKEND_RequestInputByIndex(request, 0 /* index */, &input), + "failed getting request input"); + if (input != nullptr) { + const int64_t* shape = nullptr; + LOG_IF_ERROR( + TRITONBACKEND_InputProperties( + input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr), + "failed getting input properties"); + if (shape != nullptr) { + total_batch_size += shape[0]; + } + } + } + } +#else + (void)exec_start_ns; + (void)exec_end_ns; + (void)compute_start_ns; + (void)compute_end_ns; +#endif // TRITON_ENABLE_STATS + + // Report statistics for each request, and then release the request. + for (uint32_t r = 0; r < request_count; ++r) { + auto& request = requests[r]; + +#ifdef TRITON_ENABLE_STATS + LOG_IF_ERROR( + TRITONBACKEND_ModelInstanceReportStatistics( + instance_state->TritonModelInstance(), request, + (responses[r] != nullptr) /* success */, exec_start_ns, + compute_start_ns, compute_end_ns, exec_end_ns), + "failed reporting request statistics"); +#endif // TRITON_ENABLE_STATS + + LOG_IF_ERROR( + TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), + "failed releasing request"); + } + +#ifdef TRITON_ENABLE_STATS + // Report batch statistics. + LOG_IF_ERROR( + TRITONBACKEND_ModelInstanceReportBatchStatistics( + instance_state->TritonModelInstance(), total_batch_size, + exec_start_ns, compute_start_ns, compute_end_ns, exec_end_ns), + "failed reporting batch request statistics"); +#endif // TRITON_ENABLE_STATS + + return nullptr; // success +} + +} // extern "C" + +}}} // namespace triton::backend::recommended diff --git a/3rdparty/backend-r22.12/examples/clients/bls_client b/3rdparty/backend-r22.12/examples/clients/bls_client new file mode 100644 index 0000000000000000000000000000000000000000..82090901d7b06ee17e5d511eae775c4b84a451d5 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/clients/bls_client @@ -0,0 +1,86 @@ +#!/usr/bin/python +# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys +import argparse +import numpy as np +import tritonhttpclient as httpclient +from tritonclientutils import np_to_triton_dtype + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-u', + '--url', + type=str, + required=False, + default='localhost:8000', + help='Inference server URL. Default is localhost:8000.') + FLAGS = parser.parse_args() + + model_name = "bls_fp32" + shape = [16] + with httpclient.InferenceServerClient(url=FLAGS.url) as client: + input0_data = np.random.rand(*shape).astype(np.float32) + input1_data = np.random.rand(*shape).astype(np.float32) + inputs = [ + httpclient.InferInput("INPUT0", input0_data.shape, + np_to_triton_dtype(input0_data.dtype)), + httpclient.InferInput("INPUT1", input1_data.shape, + np_to_triton_dtype(input1_data.dtype)), + ] + + inputs[0].set_data_from_numpy(input0_data) + inputs[1].set_data_from_numpy(input1_data) + + outputs = [ + httpclient.InferRequestedOutput("OUTPUT0"), + httpclient.InferRequestedOutput("OUTPUT1"), + ] + response = client.infer(model_name, + inputs, + request_id=str(1), + outputs=outputs) + + result = response.get_response() + output0_data = response.as_numpy("OUTPUT0") + output1_data = response.as_numpy("OUTPUT1") + + print("INPUT0 ({}) + INPUT1 ({}) = OUTPUT0 ({})".format( + input0_data, input1_data, output0_data)) + print("INPUT0 ({}) - INPUT1 ({}) = OUTPUT1 ({})".format( + input0_data, input1_data, output1_data)) + + if not np.allclose(input0_data + input1_data, output0_data): + print("error: incorrect sum") + sys.exit(1) + + if not np.allclose(input0_data - input1_data, output1_data): + print("error: incorrect difference") + sys.exit(1) + + print('\nPASS') + sys.exit(0) diff --git a/3rdparty/backend-r22.12/examples/clients/minimal_client b/3rdparty/backend-r22.12/examples/clients/minimal_client new file mode 100644 index 0000000000000000000000000000000000000000..ffead3480ffa50270d73495d8a9358e2f1f2ea79 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/clients/minimal_client @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import numpy as np + +import tritonclient.http as httpclient +from tritonclient.utils import InferenceServerException + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-u', + '--url', + type=str, + required=False, + default='localhost:8000', + help='Inference server URL. Default is localhost:8000.') + FLAGS = parser.parse_args() + + # For the HTTP client, need to specify large enough concurrency to + # issue all the inference requests to the server in parallel. For + # this example we want to be able to send 2 requests concurrently. + try: + concurrent_request_count = 2 + triton_client = httpclient.InferenceServerClient( + url=FLAGS.url, concurrency=concurrent_request_count) + except Exception as e: + print("channel creation failed: " + str(e)) + sys.exit(1) + + # First send a single request to the nonbatching model. + print('=========') + input0_data = np.array([ 1, 2, 3, 4 ], dtype=np.int32) + print('Sending request to nonbatching model: IN0 = {}'.format(input0_data)) + + inputs = [ httpclient.InferInput('IN0', [4], "INT32") ] + inputs[0].set_data_from_numpy(input0_data) + result = triton_client.infer('nonbatching', inputs) + + print('Response: {}'.format(result.get_response())) + print('OUT0 = {}'.format(result.as_numpy('OUT0'))) + + # Send 2 requests to the batching model. Because these are sent + # asynchronously and Triton's dynamic batcher is configured to + # delay up to 5 seconds when forming a batch for this model, we + # expect these 2 requests to be batched within Triton and sent to + # the minimal backend as a single batch. + print('\n=========') + async_requests = [] + + input0_data = np.array([[ 10, 11, 12, 13 ]], dtype=np.int32) + print('Sending request to batching model: IN0 = {}'.format(input0_data)) + inputs = [ httpclient.InferInput('IN0', [1, 4], "INT32") ] + inputs[0].set_data_from_numpy(input0_data) + async_requests.append(triton_client.async_infer('batching', inputs)) + + input0_data = np.array([[ 20, 21, 22, 23 ]], dtype=np.int32) + print('Sending request to batching model: IN0 = {}'.format(input0_data)) + inputs = [ httpclient.InferInput('IN0', [1, 4], "INT32") ] + inputs[0].set_data_from_numpy(input0_data) + async_requests.append(triton_client.async_infer('batching', inputs)) + + for async_request in async_requests: + # Get the result from the initiated asynchronous inference + # request. This call will block till the server responds. 
+ result = async_request.get_result() + print('Response: {}'.format(result.get_response())) + print('OUT0 = {}'.format(result.as_numpy('OUT0'))) diff --git a/3rdparty/backend-r22.12/examples/clients/recommended_client b/3rdparty/backend-r22.12/examples/clients/recommended_client new file mode 100644 index 0000000000000000000000000000000000000000..4a586d2b6d7247e1d715dd7d7e0e41a1a877227f --- /dev/null +++ b/3rdparty/backend-r22.12/examples/clients/recommended_client @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import numpy as np + +import tritonclient.http as httpclient +from tritonclient.utils import InferenceServerException + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-u', + '--url', + type=str, + required=False, + default='localhost:8000', + help='Inference server URL. Default is localhost:8000.') + FLAGS = parser.parse_args() + + # For the HTTP client, need to specify large enough concurrency to + # issue all the inference requests to the server in parallel. For + # this example we want to be able to send 2 requests concurrently. + try: + concurrent_request_count = 2 + triton_client = httpclient.InferenceServerClient( + url=FLAGS.url, concurrency=concurrent_request_count) + except Exception as e: + print("channel creation failed: " + str(e)) + sys.exit(1) + + # Send 2 requests to the batching model. Because these are sent + # asynchronously and Triton's dynamic batcher is configured to + # delay up to 5 seconds when forming a batch for this model, we + # expect these 2 requests to be batched within Triton and sent to + # the backend as a single batch. + # + # The recommended backend can handle any model with 1 input and 1 + # output as long as the input and output datatype and shape are + # the same. The batching model uses datatype FP32 and shape + # [ 4, 4 ]. 
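+    #
+    # For reference only (hypothetical; the model repository for this
+    # example is not part of this diff): the 'batching' model's
+    # config.pbtxt would look roughly like the following, with the
+    # dynamic batcher's max_queue_delay_microseconds set to 5 seconds
+    # and max_batch_size chosen as an example value:
+    #
+    #   name: "batching"
+    #   backend: "recommended"
+    #   max_batch_size: 8
+    #   input [ { name: "INPUT" data_type: TYPE_FP32 dims: [ 4, 4 ] } ]
+    #   output [ { name: "OUTPUT" data_type: TYPE_FP32 dims: [ 4, 4 ] } ]
+    #   dynamic_batching { max_queue_delay_microseconds: 5000000 }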
+ print('\n=========') + async_requests = [] + + input0_data = np.array([[[ 1.0, 1.1, 1.2, 1.3 ], + [ 2.0, 2.1, 2.2, 2.3 ], + [ 3.0, 3.1, 3.2, 3.3 ], + [ 4.0, 4.1, 4.2, 4.3 ]]], dtype=np.float32) + print('Sending request to batching model: input = {}'.format(input0_data)) + inputs = [ httpclient.InferInput('INPUT', [1, 4, 4], "FP32") ] + inputs[0].set_data_from_numpy(input0_data) + async_requests.append(triton_client.async_infer('batching', inputs)) + + input0_data = np.array([[[ 10.0, 10.1, 10.2, 10.3 ], + [ 20.0, 20.1, 20.2, 20.3 ], + [ 30.0, 30.1, 30.2, 30.3 ], + [ 40.0, 40.1, 40.2, 40.3 ]]], dtype=np.float32) + print('Sending request to batching model: input = {}'.format(input0_data)) + inputs = [ httpclient.InferInput('INPUT', [1, 4, 4], "FP32") ] + inputs[0].set_data_from_numpy(input0_data) + async_requests.append(triton_client.async_infer('batching', inputs)) + + for async_request in async_requests: + # Get the result from the initiated asynchronous inference + # request. This call will block till the server responds. + result = async_request.get_result() + print('Response: {}'.format(result.get_response())) + print('OUTPUT = {}'.format(result.as_numpy('OUTPUT'))) diff --git a/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_python/1/model.py b/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_python/1/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f0ef2c8b205c8b8425074d04617a8b624fd93347 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_python/1/model.py @@ -0,0 +1,74 @@ +# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import triton_python_backend_utils as pb_utils + +# This model calculates the sum and difference of the INPUT0 and INPUT1 and put +# the results in OUTPUT0 and OUTPUT1 respectively. For more information +# regarding how this model.py was written, please refer to Python Backend. 
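+#
+# For illustration only (hypothetical, not part of this file): a BLS
+# caller such as the 'bls_fp32' model could invoke this model from its
+# own execute() roughly as follows, assuming the standard Python-backend
+# BLS utilities:
+#
+#   infer_request = pb_utils.InferenceRequest(
+#       model_name='addsub_python',
+#       requested_output_names=['OUTPUT0', 'OUTPUT1'],
+#       inputs=[in_0, in_1])
+#   infer_response = infer_request.exec()
+#   if not infer_response.has_error():
+#       out_0 = pb_utils.get_output_tensor_by_name(infer_response, 'OUTPUT0')
+#       out_1 = pb_utils.get_output_tensor_by_name(infer_response, 'OUTPUT1')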
+class TritonPythonModel: + + def initialize(self, args): + self.model_config = model_config = json.loads(args['model_config']) + + output0_config = pb_utils.get_output_config_by_name( + model_config, "OUTPUT0") + + output1_config = pb_utils.get_output_config_by_name( + model_config, "OUTPUT1") + + self.output0_dtype = pb_utils.triton_string_to_numpy( + output0_config['data_type']) + self.output1_dtype = pb_utils.triton_string_to_numpy( + output1_config['data_type']) + + def execute(self, requests): + output0_dtype = self.output0_dtype + output1_dtype = self.output1_dtype + + responses = [] + + for request in requests: + in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0") + in_1 = pb_utils.get_input_tensor_by_name(request, "INPUT1") + + out_0, out_1 = (in_0.as_numpy() + in_1.as_numpy(), + in_0.as_numpy() - in_1.as_numpy()) + + out_tensor_0 = pb_utils.Tensor("OUTPUT0", + out_0.astype(output0_dtype)) + out_tensor_1 = pb_utils.Tensor("OUTPUT1", + out_1.astype(output1_dtype)) + + inference_response = pb_utils.InferenceResponse( + output_tensors=[out_tensor_0, out_tensor_1]) + responses.append(inference_response) + + return responses + + def finalize(self): + print('Cleaning up...') diff --git a/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_python/config.pbtxt b/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_python/config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a0025a0ed1ce985467709b814ce2ba8a38bc4829 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_python/config.pbtxt @@ -0,0 +1,58 @@ +# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +name: "addsub_python" +backend: "python" +max_batch_size: 0 + +input [ + { + name: "INPUT0" + data_type: TYPE_FP32 + dims: [ 16 ] + } +] +input [ + { + name: "INPUT1" + data_type: TYPE_FP32 + dims: [ 16 ] + } +] +output [ + { + name: "OUTPUT0" + data_type: TYPE_FP32 + dims: [ 16 ] + } +] +output [ + { + name: "OUTPUT1" + data_type: TYPE_FP32 + dims: [ 16 ] + } +] diff --git a/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_tf/1/model.savedmodel/saved_model.pb b/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_tf/1/model.savedmodel/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..7a7cc038c720d6293b9ee48f2b5a68cc1bc6a7a7 Binary files /dev/null and b/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_tf/1/model.savedmodel/saved_model.pb differ diff --git a/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_tf/config.pbtxt b/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_tf/config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ec176a0bd6e873f3c7f41da761576fc30c16ba3e --- /dev/null +++ b/3rdparty/backend-r22.12/examples/model_repos/bls_models/addsub_tf/config.pbtxt @@ -0,0 +1,28 @@ +name: "addsub_tf" +platform: "tensorflow_savedmodel" +max_batch_size: 0 + +input [ + { + name: "INPUT0" + data_type: TYPE_FP32 + dims: [ 16 ] + }, + { + name: "INPUT1" + data_type: TYPE_FP32 + dims: [ 16 ] + } +] +output [ + { + name: "OUTPUT0" + data_type: TYPE_FP32 + dims: [ 16 ] + }, + { + name: "OUTPUT1" + data_type: TYPE_FP32 + dims: [ 16 ] + } +] diff --git a/3rdparty/backend-r22.12/examples/model_repos/bls_models/bls_fp32/config.pbtxt b/3rdparty/backend-r22.12/examples/model_repos/bls_models/bls_fp32/config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f8c6c067bc4e75369da107e571b0282e1eb73d59 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/model_repos/bls_models/bls_fp32/config.pbtxt @@ -0,0 +1,63 @@ +# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +name: "bls_fp32" +backend: "bls" +max_batch_size: 0 + +input [ + { + name: "INPUT0" + data_type: TYPE_FP32 + dims: [ 16 ] + } +] +input [ + { + name: "INPUT1" + data_type: TYPE_FP32 + dims: [ 16 ] + } +] +output [ + { + name: "OUTPUT0" + data_type: TYPE_FP32 + dims: [ 16 ] + } +] +output [ + { + name: "OUTPUT1" + data_type: TYPE_FP32 + dims: [ 16 ] + } +] +instance_group [ + { + kind: KIND_CPU + } +] diff --git a/3rdparty/backend-r22.12/examples/model_repos/minimal_models/batching/1/.gitkeep b/3rdparty/backend-r22.12/examples/model_repos/minimal_models/batching/1/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/3rdparty/backend-r22.12/examples/model_repos/minimal_models/batching/config.pbtxt b/3rdparty/backend-r22.12/examples/model_repos/minimal_models/batching/config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f7423fb542c73fbc090ea99fa5bdc58cb03a664c --- /dev/null +++ b/3rdparty/backend-r22.12/examples/model_repos/minimal_models/batching/config.pbtxt @@ -0,0 +1,24 @@ +backend: "minimal" +max_batch_size: 8 +dynamic_batching { + max_queue_delay_microseconds: 5000000 +} +input [ + { + name: "IN0" + data_type: TYPE_INT32 + dims: [ 4 ] + } +] +output [ + { + name: "OUT0" + data_type: TYPE_INT32 + dims: [ 4 ] + } +] +instance_group [ + { + kind: KIND_CPU + } +] diff --git a/3rdparty/backend-r22.12/examples/model_repos/minimal_models/nonbatching/1/.gitkeep b/3rdparty/backend-r22.12/examples/model_repos/minimal_models/nonbatching/1/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/3rdparty/backend-r22.12/examples/model_repos/minimal_models/nonbatching/config.pbtxt b/3rdparty/backend-r22.12/examples/model_repos/minimal_models/nonbatching/config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..01d75a9785e00e8c1d1b18d15b57c49dd908ec93 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/model_repos/minimal_models/nonbatching/config.pbtxt @@ -0,0 +1,21 @@ +backend: "minimal" +max_batch_size: 0 +input [ + { + name: "IN0" + data_type: TYPE_INT32 + dims: [ 4 ] + } +] +output [ + { + name: "OUT0" + data_type: TYPE_INT32 + dims: [ 4 ] + } +] +instance_group [ + { + kind: KIND_CPU + } +] diff --git a/3rdparty/backend-r22.12/examples/model_repos/recommended_models/batching/1/.gitkeep b/3rdparty/backend-r22.12/examples/model_repos/recommended_models/batching/1/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/3rdparty/backend-r22.12/examples/model_repos/recommended_models/batching/config.pbtxt b/3rdparty/backend-r22.12/examples/model_repos/recommended_models/batching/config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..917ebf27ff133d5466d9f7c7618b1839987391a0 --- /dev/null +++ b/3rdparty/backend-r22.12/examples/model_repos/recommended_models/batching/config.pbtxt @@ -0,0 +1,24 @@ +backend: "recommended" +max_batch_size: 8 +dynamic_batching { + max_queue_delay_microseconds: 5000000 +} +input [ + { + name: "INPUT" + data_type: TYPE_FP32 + dims: [ 4, 4 ] + } +] +output [ + { + name: "OUTPUT" + data_type: TYPE_FP32 + dims: [ 4, 4 ] + } +] +instance_group [ + { + kind: KIND_CPU + } +] diff --git a/3rdparty/backend-r22.12/include/triton/backend/backend_common.h b/3rdparty/backend-r22.12/include/triton/backend/backend_common.h new file mode 100644 index 
0000000000000000000000000000000000000000..aad3a5a4db48b6cb81700adacccb377375082ce6 --- /dev/null +++ b/3rdparty/backend-r22.12/include/triton/backend/backend_common.h @@ -0,0 +1,672 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "triton/common/error.h" +#include "triton/core/tritonbackend.h" + +#define TRITONJSON_STATUSTYPE TRITONSERVER_Error* +#define TRITONJSON_STATUSRETURN(M) \ + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, (M).c_str()) +#define TRITONJSON_STATUSSUCCESS nullptr +#include "triton/common/triton_json.h" + +#ifdef TRITON_ENABLE_GPU +#include +#endif // TRITON_ENABLE_GPU + +namespace triton { namespace backend { + +#define IGNORE_ERROR(X) \ + do { \ + TRITONSERVER_Error* ie_err__ = (X); \ + if (ie_err__ != nullptr) { \ + TRITONSERVER_ErrorDelete(ie_err__); \ + } \ + } while (false) + +#define LOG_IF_ERROR(X, MSG) \ + do { \ + TRITONSERVER_Error* lie_err__ = (X); \ + if (lie_err__ != nullptr) { \ + IGNORE_ERROR(TRITONSERVER_LogMessage( \ + TRITONSERVER_LOG_INFO, __FILE__, __LINE__, \ + (std::string(MSG) + ": " + TRITONSERVER_ErrorCodeString(lie_err__) + \ + " - " + TRITONSERVER_ErrorMessage(lie_err__)) \ + .c_str())); \ + TRITONSERVER_ErrorDelete(lie_err__); \ + } \ + } while (false) + +#define LOG_MESSAGE(LEVEL, MSG) \ + do { \ + LOG_IF_ERROR( \ + TRITONSERVER_LogMessage(LEVEL, __FILE__, __LINE__, MSG), \ + ("failed to log message: ")); \ + } while (false) + + +#define RETURN_ERROR_IF_FALSE(P, C, MSG) \ + do { \ + if (!(P)) { \ + return TRITONSERVER_ErrorNew(C, (MSG).c_str()); \ + } \ + } while (false) + +#define RETURN_ERROR_IF_TRUE(P, C, MSG) \ + do { \ + if ((P)) { \ + return TRITONSERVER_ErrorNew(C, (MSG).c_str()); \ + } \ + } while (false) + +#define RETURN_IF_ERROR(X) \ + do { \ + TRITONSERVER_Error* rie_err__ = (X); \ + if (rie_err__ != nullptr) { \ + return rie_err__; \ + } \ + } while (false) + +#ifdef TRITON_ENABLE_GPU +#define 
LOG_IF_CUDA_ERROR(X, MSG) \ + do { \ + cudaError_t lice_err__ = (X); \ + if (lice_err__ != cudaSuccess) { \ + IGNORE_ERROR(TRITONSERVER_LogMessage( \ + TRITONSERVER_LOG_INFO, __FILE__, __LINE__, \ + (std::string(MSG) + ": " + cudaGetErrorString(lice_err__)) \ + .c_str())); \ + } \ + } while (false) + +#define RETURN_IF_CUDA_ERROR(X, C, MSG) \ + do { \ + cudaError_t rice_err__ = (X); \ + if (rice_err__ != cudaSuccess) { \ + return TRITONSERVER_ErrorNew( \ + C, ((MSG) + ": " + cudaGetErrorString(rice_err__)).c_str()); \ + } \ + } while (false) +#endif // TRITON_ENABLE_GPU + +#define RESPOND_AND_SET_NULL_IF_ERROR(RESPONSE_PTR, X) \ + do { \ + TRITONSERVER_Error* rarie_err__ = (X); \ + if (rarie_err__ != nullptr) { \ + if (*RESPONSE_PTR != nullptr) { \ + LOG_IF_ERROR( \ + TRITONBACKEND_ResponseSend( \ + *RESPONSE_PTR, TRITONSERVER_RESPONSE_COMPLETE_FINAL, \ + rarie_err__), \ + "failed to send error response"); \ + *RESPONSE_PTR = nullptr; \ + } \ + TRITONSERVER_ErrorDelete(rarie_err__); \ + } \ + } while (false) + +#define RESPOND_ALL_AND_SET_NULL_IF_ERROR(RESPONSES, RESPONSES_COUNT, X) \ + do { \ + TRITONSERVER_Error* raasnie_err__ = (X); \ + if (raasnie_err__ != nullptr) { \ + for (size_t ridx = 0; ridx < RESPONSES_COUNT; ++ridx) { \ + if (RESPONSES[ridx] != nullptr) { \ + LOG_IF_ERROR( \ + TRITONBACKEND_ResponseSend( \ + RESPONSES[ridx], TRITONSERVER_RESPONSE_COMPLETE_FINAL, \ + raasnie_err__), \ + "failed to send error response"); \ + RESPONSES[ridx] = nullptr; \ + } \ + } \ + TRITONSERVER_ErrorDelete(raasnie_err__); \ + } \ + } while (false) + +#define RESPOND_ALL_AND_SET_TRUE_IF_ERROR(RESPONSES, RESPONSES_COUNT, BOOL, X) \ + do { \ + TRITONSERVER_Error* raasnie_err__ = (X); \ + if (raasnie_err__ != nullptr) { \ + BOOL = true; \ + for (size_t ridx = 0; ridx < RESPONSES_COUNT; ++ridx) { \ + if (RESPONSES[ridx] != nullptr) { \ + LOG_IF_ERROR( \ + TRITONBACKEND_ResponseSend( \ + RESPONSES[ridx], TRITONSERVER_RESPONSE_COMPLETE_FINAL, \ + raasnie_err__), \ + "failed to send error response"); \ + RESPONSES[ridx] = nullptr; \ + } \ + } \ + TRITONSERVER_ErrorDelete(raasnie_err__); \ + } \ + } while (false) + +#ifdef TRITON_ENABLE_STATS +#define TIMESPEC_TO_NANOS(TS) ((TS).tv_sec * 1000000000 + (TS).tv_nsec) +#define SET_TIMESTAMP(TS_NS) \ + { \ + TS_NS = std::chrono::duration_cast( \ + std::chrono::steady_clock::now().time_since_epoch()) \ + .count(); \ + } +#define DECL_TIMESTAMP(TS_NS) \ + uint64_t TS_NS; \ + SET_TIMESTAMP(TS_NS); +#else +#define DECL_TIMESTAMP(TS_NS) +#define SET_TIMESTAMP(TS_NS) +#endif // TRITON_ENABLE_STATS + +#ifndef TRITON_ENABLE_GPU +using cudaStream_t = void*; +#endif // !TRITON_ENABLE_GPU + +/// Convenience deleter for TRITONBACKEND_ResponseFactory. 
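These logging and error macros are the building blocks used throughout backend entry points. As a minimal, hedged sketch (the function body is illustrative and not taken from this repository's examples), a TRITONBACKEND_ModelInitialize implementation typically combines RETURN_IF_ERROR with LOG_MESSAGE like this:

#include <string>

#include "triton/backend/backend_common.h"

extern "C" TRITONSERVER_Error*
TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model)
{
  // Fetch the model name through the core API; RETURN_IF_ERROR propagates
  // any failure back to Triton, which will then fail the model load.
  const char* cname;
  RETURN_IF_ERROR(TRITONBACKEND_ModelName(model, &cname));

  // LOG_MESSAGE wraps TRITONSERVER_LogMessage and swallows logging errors.
  LOG_MESSAGE(
      TRITONSERVER_LOG_INFO,
      (std::string("TRITONBACKEND_ModelInitialize: ") + cname).c_str());

  return nullptr;  // success
}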
+struct ResponseFactoryDeleter { + void operator()(TRITONBACKEND_ResponseFactory* f) + { + LOG_IF_ERROR( + TRITONBACKEND_ResponseFactoryDelete(f), + "failed deleting response factory"); + } +}; + +// A representation of the BatchInput message in model config +class BatchInput { + public: + enum class Kind { + BATCH_ELEMENT_COUNT, + BATCH_ACCUMULATED_ELEMENT_COUNT, + BATCH_ACCUMULATED_ELEMENT_COUNT_WITH_ZERO, + BATCH_MAX_ELEMENT_COUNT_AS_SHAPE, + BATCH_ITEM_SHAPE, + BATCH_ITEM_SHAPE_FLATTEN + }; + static TRITONSERVER_Error* ParseFromModelConfig( + triton::common::TritonJson::Value& config, + std::vector* batch_inputs); + const std::vector& TargetNames() const { return target_names_; } + TRITONSERVER_DataType DataType() const { return data_type_; } + Kind BatchInputKind() const { return kind_; } + std::string BatchInputKindString() const { return kind_str_; } + const std::vector& SourceInputs() const + { + return source_inputs_; + } + + private: + TRITONSERVER_Error* Init(triton::common::TritonJson::Value& bi_config); + Kind kind_; + std::string kind_str_; + std::vector target_names_; + TRITONSERVER_DataType data_type_; + std::vector source_inputs_; +}; + +// A representation of the BatchOutput message in model config +class BatchOutput { + public: + enum class Kind { BATCH_SCATTER_WITH_INPUT_SHAPE }; + static TRITONSERVER_Error* ParseFromModelConfig( + triton::common::TritonJson::Value& config, + std::vector* batch_outputs); + const std::vector& TargetNames() const { return target_names_; } + TRITONSERVER_DataType DataType() const { return data_type_; } + const std::vector& OutputShape() const { return shape_; } + Kind BatchOutputKind() const { return kind_; } + const std::vector& SourceInputs() const + { + return source_inputs_; + } + + private: + Kind kind_; + std::vector target_names_; + TRITONSERVER_DataType data_type_; + std::vector shape_; + std::vector source_inputs_; +}; + +struct CopyParams { + CopyParams(void* dst, const void* src, const size_t byte_size) + : dst_(dst), src_(src), byte_size_(byte_size) + { + } + + void* dst_; + const void* src_; + const size_t byte_size_; +}; + +/// The value for a dimension in a shape that indicates that that +/// dimension can take on any size. +constexpr int WILDCARD_DIM = -1; + +constexpr char kTensorRTExecutionAccelerator[] = "tensorrt"; +constexpr char kOpenVINOExecutionAccelerator[] = "openvino"; +constexpr char kGPUIOExecutionAccelerator[] = "gpu_io"; +constexpr char kAutoMixedPrecisionExecutionAccelerator[] = + "auto_mixed_precision"; + +TRITONSERVER_MemoryType GetUsePinnedMemoryType( + TRITONSERVER_MemoryType ref_buffer_type); + +TRITONSERVER_Error* CommonErrorToTritonError(triton::common::Error error); + +TRITONSERVER_Error_Code StatusCodeToTritonCode( + triton::common::Error::Code error_code); + +/// Parse an array in a JSON object into the corresponding shape. The +/// array must be composed of integers. +/// +/// \param io The JSON object containing the member array. +/// \param name The name of the array member in the JSON object. +/// \param shape Returns the shape. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* ParseShape( + common::TritonJson::Value& io, const std::string& name, + std::vector* shape); + +/// Return the string representation of a shape. +/// +/// \param dims The shape dimensions. +/// \param dims_count The number of dimensions. +/// \return The string representation. 
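As a hedged illustration of ParseShape and the shape helpers declared here and immediately below, the following sketch parses an output's dims from the model configuration and logs its byte size; the ValidateOutputShape wrapper, the output_json value, and the FP32 datatype are assumptions made for the example, not something defined by this header:

#include <cstdint>
#include <string>
#include <vector>

#include "triton/backend/backend_common.h"

using namespace triton::backend;

// Sketch: 'output_json' holds one entry of the config's "output" array.
TRITONSERVER_Error*
ValidateOutputShape(triton::common::TritonJson::Value& output_json)
{
  std::vector<int64_t> shape;
  RETURN_IF_ERROR(ParseShape(output_json, "dims", &shape));

  // GetByteSize returns -1 when the shape contains WILDCARD_DIM (-1),
  // i.e. when the size cannot be determined statically.
  const int64_t byte_size = GetByteSize(TRITONSERVER_TYPE_FP32, shape);

  LOG_MESSAGE(
      TRITONSERVER_LOG_VERBOSE,
      (std::string("output dims ") + ShapeToString(shape) + ", byte size " +
       std::to_string(byte_size))
          .c_str());
  return nullptr;  // success
}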
+std::string ShapeToString(const int64_t* dims, const size_t dims_count); + +/// Return the string representation of a shape. +/// +/// \param shape The shape as a vector of dimensions. +/// \return The string representation. +std::string ShapeToString(const std::vector& shape); + +/// Return the number of elements of a shape. +/// +/// \param dims The shape dimensions. +/// \param dims_count The number of dimensions. +/// \return The number of elements. +int64_t GetElementCount(const int64_t* dims, const size_t dims_count); + +/// Return the number of elements of a shape. +/// +/// \param shape The shape as a vector of dimensions. +/// \return The number of elements. +int64_t GetElementCount(const std::vector& shape); + +/// Get the size, in bytes, of a tensor based on datatype and +/// shape. +/// \param dtype The data-type. +/// \param dims The shape. +/// \return The size, in bytes, of the corresponding tensor, or -1 if +/// unable to determine the size. +int64_t GetByteSize( + const TRITONSERVER_DataType& dtype, const std::vector& dims); + +/// Get an input tensor's contents into a buffer. This overload expects +/// both 'buffer' and buffers of the input to be in CPU. +/// +/// \param request The inference request. +/// \param input_name The name of the input buffer. +/// \param buffer The buffer where the input tensor content is copied into. +/// \param buffer_byte_size Acts as both input and output. On input +/// gives the size of 'buffer', in bytes. The function will fail if +/// the buffer is not large enough to hold the input tensor +/// contents. Returns the size of the input tensor data returned in +/// 'buffer'. +/// \param host_policy_name The host policy name to look up the input buffer. +/// Default input buffer will be used if nullptr is provided. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* ReadInputTensor( + TRITONBACKEND_Request* request, const std::string& input_name, char* buffer, + size_t* buffer_byte_size, const char* host_policy_name = nullptr); + +/// Get an input tensor's contents into a buffer. This overload of +/// 'ReadInputTensor' supports input buffers that can be in any memory. +/// +/// \param request The inference request. +/// \param input_name The name of the input buffer. +/// \param buffer The buffer where the input tensor content is copied into. +/// \param buffer_byte_size Acts as both input and output. On input +/// gives the size of 'buffer', in bytes. The function will fail if +/// the buffer is not large enough to hold the input tensor +/// contents. Returns the size of the input tensor data returned in +/// 'buffer'. +/// \param host_policy_name The host policy name to look up the input buffer. +/// Default input buffer will be used if nullptr is provided. +/// \param memory_type The memory type of the buffer provided. +/// \param memory_type_id The memory type id of the buffer provided. +/// \param cuda_stream specifies the stream to be associated with, and 0 can be +/// passed for default stream. +/// \param cuda_used returns whether a CUDA memory copy is initiated. If true, +/// the caller should synchronize on the given 'cuda_stream' to ensure data copy +/// is completed. +/// \param copy_on_stream whether the memory copies should be performed in cuda +/// host functions on the 'cuda_stream'. +/// \return a TRITONSERVER_Error indicating success or failure. 
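A hedged sketch of the CPU-only ReadInputTensor overload documented above, sized for a single 4x4 FP32 tensor (64 bytes) such as the INPUT of the recommended batching model earlier in this diff; the CopyInputToCpu wrapper is illustrative:

#include <vector>

#include "triton/backend/backend_common.h"

using namespace triton::backend;

// Sketch: 'request' is one of the TRITONBACKEND_Request* values handed to the
// backend's execute function.
TRITONSERVER_Error*
CopyInputToCpu(TRITONBACKEND_Request* request)
{
  // 4x4 FP32 tensor = 64 bytes; the call fails if the input is larger.
  std::vector<char> buffer(64);
  size_t buffer_byte_size = buffer.size();
  RETURN_IF_ERROR(
      ReadInputTensor(request, "INPUT", buffer.data(), &buffer_byte_size));
  // On success 'buffer_byte_size' holds the number of bytes actually copied.
  return nullptr;  // success
}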
+TRITONSERVER_Error* ReadInputTensor( + TRITONBACKEND_Request* request, const std::string& input_name, char* buffer, + size_t* buffer_byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id, cudaStream_t cuda_stream, bool* cuda_used, + const char* host_policy_name = nullptr, const bool copy_on_stream = false); + +/// Validate that an input matches one of the allowed input names. +/// \param io The model input. +/// \param allowed The set of allowed input names. +/// \return The error status. A non-OK status indicates the input +/// is not valid. +TRITONSERVER_Error* CheckAllowedModelInput( + common::TritonJson::Value& io, const std::set& allowed); + +/// Validate that an output matches one of the allowed output names. +/// \param io The model output. +/// \param allowed The set of allowed output names. +/// \return The error status. A non-OK status indicates the output +/// is not valid. +TRITONSERVER_Error* CheckAllowedModelOutput( + common::TritonJson::Value& io, const std::set& allowed); + +/// Get the tensor name, false value, and true value for a boolean +/// sequence batcher control kind. If 'required' is true then must +/// find a tensor for the control. If 'required' is false, return +/// 'tensor_name' as empty-string if the control is not mapped to any +/// tensor. +/// +/// \param batcher The JSON object of the sequence batcher. +/// \param model_name The name of the model. +/// \param control_kind The kind of control tensor to look for. +/// \param required Whether the tensor must be specified. +/// \param tensor_name Returns the name of the tensor. +/// \param tensor_datatype Returns the data type of the tensor. +/// \param fp32_false_value Returns the float value for false if +/// the tensor type is FP32. +/// \param fp32_true_value Returns the float value for true if +/// the tensor type is FP32. +/// \param int32_false_value Returns the int value for false if +/// the tensor type is INT32. +/// \param int32_true_value Returns the int value for true if +/// the tensor type is INT32. +/// \param bool_false_value Returns the bool value for false if +/// the tensor type is BOOL. +/// \param bool_true_value Returns the bool value for true if +/// the tensor type is BOOL. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* GetBooleanSequenceControlProperties( + common::TritonJson::Value& batcher, const std::string& model_name, + const std::string& control_kind, const bool required, + std::string* tensor_name, std::string* tensor_datatype, + float* fp32_false_value, float* fp32_true_value, int32_t* int32_false_value, + int32_t* int32_true_value, bool* bool_false_value, bool* bool_true_value); + +/// Get the tensor name and datatype for a non-boolean sequence +/// batcher control kind. If 'required' is true then must find a +/// tensor for the control. If 'required' is false, return +/// 'tensor_name' as empty-string if the control is not mapped to any +/// tensor. 'tensor_datatype' returns the required datatype for the +/// control. +/// +/// \param batcher The JSON object of the sequence batcher. +/// \param model_name The name of the model. +/// \param control_kind The kind of control tensor to look for. +/// \param required Whether the tensor must be specified. +/// \param tensor_name Returns the name of the tensor. +/// \param tensor_datatype Returns the data type of the tensor. +/// \return a TRITONSERVER_Error indicating success or failure. 
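A hedged sketch of looking up the sequence batcher's START control with GetBooleanSequenceControlProperties; the FindStartControl wrapper and the surrounding model_config and model_name values are assumptions for the example:

#include <string>

#include "triton/backend/backend_common.h"

using namespace triton::backend;

// Sketch: 'model_config' is the parsed config JSON. 'required' is false, so a
// model without this control is not treated as an error.
TRITONSERVER_Error*
FindStartControl(
    triton::common::TritonJson::Value& model_config,
    const std::string& model_name, std::string* start_tensor_name)
{
  triton::common::TritonJson::Value batcher;
  if (model_config.Find("sequence_batching", &batcher)) {
    std::string datatype;
    float fp32_false, fp32_true;
    int32_t int32_false, int32_true;
    bool bool_false, bool_true;
    RETURN_IF_ERROR(GetBooleanSequenceControlProperties(
        batcher, model_name, "CONTROL_SEQUENCE_START", false /* required */,
        start_tensor_name, &datatype, &fp32_false, &fp32_true, &int32_false,
        &int32_true, &bool_false, &bool_true));
  }
  // Per the documentation above, an unmapped control yields an empty name.
  return nullptr;  // success
}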
+TRITONSERVER_Error* GetTypedSequenceControlProperties( + common::TritonJson::Value& batcher, const std::string& model_name, + const std::string& control_kind, const bool required, + std::string* tensor_name, std::string* tensor_datatype); + +/// Create and send an error response for a set of requests. This +/// function takes ownership of 'response_err' and so the caller must +/// not access or delete it after this call returns. +/// +/// \param requests The requests. +/// \param request_count The number of 'requests'. +/// \param response_err The error to send to each request. +/// \param release_request If true, the requests will be released after +/// sending the error responses and the request pointers are set to +/// nullptr. +void RequestsRespondWithError( + TRITONBACKEND_Request** requests, const uint32_t request_count, + TRITONSERVER_Error* response_err, const bool release_request = true); + +/// Send an error response for a set of responses. This function takes +/// ownership of 'response_err' and so the caller must not access or +/// delete it after this call returns. +/// +/// \param responses The responses. +/// \param response_count The number of 'responses'. +/// \param response_err The error to send. +void SendErrorForResponses( + std::vector* responses, + const uint32_t response_count, TRITONSERVER_Error* response_err); + +/// Copy buffer from 'src' to 'dst' for given 'byte_size'. The buffer location +/// is identified by the memory type and id, and the corresponding copy will be +/// initiated. +/// \param msg The message to be prepended in error message. +/// \param src_memory_type The memory type of the source buffer. +/// \param src_memory_type_id The memory type id of the source buffer. +/// \param dst_memory_type The memory type of the destination buffer. +/// \param dst_memory_type_id The memory type id of the destination buffer. +/// \param byte_size The byte size of the source buffer. +/// \param src The pointer to the source buffer. +/// \param dst The pointer to the destination buffer. +/// \param cuda_stream specifies the stream to be associated with, and 0 can be +/// passed for default stream. +/// \param cuda_used returns whether a CUDA memory copy is initiated. If true, +/// the caller should synchronize on the given 'cuda_stream' to ensure data copy +/// is completed. +/// \param copy_on_stream whether the memory copies should be performed in cuda +/// host functions on the 'cuda_stream'. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* CopyBuffer( + const std::string& msg, const TRITONSERVER_MemoryType src_memory_type, + const int64_t src_memory_type_id, + const TRITONSERVER_MemoryType dst_memory_type, + const int64_t dst_memory_type_id, const size_t byte_size, const void* src, + void* dst, cudaStream_t cuda_stream, bool* cuda_used, + const bool copy_on_stream = false); + +/// Does a file or directory exist? +/// \param path The path to check for existance. +/// \param exists Returns true if file/dir exists +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* FileExists(const std::string& path, bool* exists); + +/// Read a text file into a string. +/// \param path The path of the file. +/// \param contents Returns the contents of the file. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* ReadTextFile( + const std::string& path, std::string* contents); + +/// Is a path a directory? +/// \param path The path to check. 
+/// \param is_dir Returns true if path represents a directory +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* IsDirectory(const std::string& path, bool* is_dir); + +/// Join path segments into a longer path +/// \param segments The path segments. +/// \return the path formed by joining the segments. +std::string JoinPath(std::initializer_list segments); + +/// Returns the content in the model version path and the path to the content as +/// key-value pair. +/// \param model_repository_path The path to the model repository. +/// \param version The version of the model. +/// \param ignore_directories Whether the directories will be ignored. +/// \param ignore_files Whether the files will be ignored. +/// \param model_paths Returns the content in the model version path and +/// the path to the content. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* ModelPaths( + const std::string& model_repository_path, uint64_t version, + const bool ignore_directories, const bool ignore_files, + std::unordered_map* model_paths); + +/// Create a CUDA stream appropriate for GPU<->CPU data transfer +/// operations for a given GPU device. The caller takes ownership of +/// the stream. 'stream' returns nullptr if GPU support is disabled. +/// +/// \param device_id The ID of the GPU. +/// \param priority The stream priority. Use 0 for normal priority. +/// \param stream Returns the created stream. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* CreateCudaStream( + const int device_id, const int cuda_stream_priority, cudaStream_t* stream); + +/// Parse the string as long long integer. +/// +/// \param value The string. +/// \param parse_value The long long integral value of the string. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* ParseLongLongValue( + const std::string& value, int64_t* parsed_value); + +/// Parse the string as unsigned long long integer. +/// +/// \param value The string. +/// \param parse_value The unsigned long long integral value of the string. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* ParseUnsignedLongLongValue( + const std::string& value, uint64_t* parsed_value); + +/// Parse the string as boolean. +/// +/// \param value The string. +/// \param parse_value The boolean value of the string. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* ParseBoolValue( + const std::string& value, bool* parsed_value); + +/// Parse the string as integer. +/// +/// \param value The string. +/// \param parse_value The integral value of the string. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* ParseIntValue(const std::string& value, int* parsed_value); + +/// Parse the string as double. +/// +/// \param value The string. +/// \param parse_value The double value of the string. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* ParseDoubleValue( + const std::string& value, double* parsed_value); + +/// Return the value of the specified key in a JSON object. +/// +/// \param params The JSON object containing the key-value mapping. +/// \param key The key to look up the value in the JSON object. +/// \param value Returns the value. +/// \return a TRITONSERVER_Error indicating success or failure. 
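A hedged sketch combining JoinPath, FileExists and RETURN_ERROR_IF_FALSE to check for a version artifact; the 'model.txt' file name and the CheckArtifact wrapper are purely illustrative, and the repository path and version would normally come from the BackendModel class declared later in backend_model.h:

#include <cstdint>
#include <string>

#include "triton/backend/backend_common.h"

using namespace triton::backend;

TRITONSERVER_Error*
CheckArtifact(const std::string& repository_path, const uint64_t version)
{
  // e.g. <repo>/<version>/model.txt
  const std::string path =
      JoinPath({repository_path, std::to_string(version), "model.txt"});

  bool exists = false;
  RETURN_IF_ERROR(FileExists(path, &exists));
  RETURN_ERROR_IF_FALSE(
      exists, TRITONSERVER_ERROR_NOT_FOUND,
      std::string("unable to find '") + path + "'");
  return nullptr;  // success
}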
+TRITONSERVER_Error* GetParameterValue( + triton::common::TritonJson::Value& params, const std::string& key, + std::string* value); + +/// Return the Triton server data type of the data type string specified +/// in model config JSON. +/// +/// \param data_type_str The string representation of the data type. +/// \return the Triton server data type. +TRITONSERVER_DataType ModelConfigDataTypeToTritonServerDataType( + const std::string& data_type_str); + +/// Try to parse the requested parameter. +/// +/// \param params The param in model config +/// \param mkey Key in the model config. +/// \param value The parsed string value. +/// \param default_value Default value to use when key is not found. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* TryParseModelStringParameter( + triton::common::TritonJson::Value& params, const std::string& mkey, + std::string* value, const std::string& default_value); + +/// Try to parse the requested parameter. +/// +/// \param params The param in model config +/// \param mkey Key in the model config. +/// \param value The parsed int value. +/// \param default_value Default value to use when key is not found. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* TryParseModelStringParameter( + triton::common::TritonJson::Value& params, const std::string& mkey, + int* value, const int& default_value); + +/// Try to parse the requested parameter. +/// +/// \param params The param in model config +/// \param mkey Key in the model config. +/// \param value The parsed bool value. +/// \param default_value Default value to use when key is not found. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* TryParseModelStringParameter( + triton::common::TritonJson::Value& params, const std::string& mkey, + bool* value, const bool& default_value); + +/// Try to parse the requested parameter. +/// +/// \param params The param in model config +/// \param mkey Key in the model config. +/// \param value The parsed uint64 value. +/// \param default_value Default value to use when key is not found. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* TryParseModelStringParameter( + triton::common::TritonJson::Value& params, const std::string& mkey, + uint64_t* value, const uint64_t& default_value); + +/// Get a string representation of a tensor buffer. +/// +/// \param str Returns the string. +/// \param buffer The base pointer to the tensor buffer. +/// \param buffer_byte_size The size of the buffer in bytes. +/// \param datatype The type of the tensor +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_Error* BufferAsTypedString( + std::string& str, const char* buffer, size_t buffer_byte_size, + TRITONSERVER_DataType datatype); + +/// Get the ID of the request as a string formatted for logging. +/// +/// \param request Request of which to get the ID. +/// \return a formatted string for logging the request ID. 
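A hedged sketch of TryParseModelStringParameter reading an optional integer from the config's parameters section; the parameter name "max_candidates" and the ReadMaxCandidates wrapper are illustrative only:

#include "triton/backend/backend_common.h"

using namespace triton::backend;

// Sketch: falls back to a default of 4 when the parameter is absent.
TRITONSERVER_Error*
ReadMaxCandidates(
    triton::common::TritonJson::Value& model_config, int* max_candidates)
{
  *max_candidates = 4;  // default
  triton::common::TritonJson::Value params;
  if (model_config.Find("parameters", &params)) {
    RETURN_IF_ERROR(TryParseModelStringParameter(
        params, "max_candidates", max_candidates, 4 /* default_value */));
  }
  return nullptr;  // success
}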
+std::string GetRequestId(TRITONBACKEND_Request* request); + +}} // namespace triton::backend diff --git a/3rdparty/backend-r22.12/include/triton/backend/backend_input_collector.h b/3rdparty/backend-r22.12/include/triton/backend/backend_input_collector.h new file mode 100644 index 0000000000000000000000000000000000000000..44a7b1bc625db649f44f21523a02c79474a54436 --- /dev/null +++ b/3rdparty/backend-r22.12/include/triton/backend/backend_input_collector.h @@ -0,0 +1,301 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include "triton/backend/backend_common.h" +#include "triton/backend/backend_memory.h" +#include "triton/common/async_work_queue.h" +#include "triton/common/sync_queue.h" +#include "triton/core/tritonbackend.h" + +#ifdef TRITON_ENABLE_GPU +#include +#endif // TRITON_ENABLE_GPU + +namespace triton { namespace backend { + +#ifndef TRITON_ENABLE_GPU +using cudaStream_t = void*; +using cudaEvent_t = void*; +#endif // !TRITON_ENABLE_GPU + +// +// BackendInputCollector +// +class BackendInputCollector { + public: + // The caller can optionally provide 'event' for internal synchronization + // instead of using 'stream'. 
If 'host_policy_name' is provided, it must be + // valid for the lifetime of the collector + explicit BackendInputCollector( + TRITONBACKEND_Request** requests, const uint32_t request_count, + std::vector* responses, + TRITONBACKEND_MemoryManager* memory_manager, const bool pinned_enabled, + cudaStream_t stream, cudaEvent_t event = nullptr, + cudaEvent_t buffer_ready_event = nullptr, + const size_t kernel_buffer_threshold = 0, + const char* host_policy_name = nullptr, const bool copy_on_stream = false, + const bool coalesce_request_input = false) + : need_sync_(false), requests_(requests), request_count_(request_count), + responses_(responses), memory_manager_(memory_manager), + pinned_enabled_(pinned_enabled), + use_async_cpu_copy_(triton::common::AsyncWorkQueue::WorkerCount() > 1), + stream_(stream), event_(event), buffer_ready_event_(buffer_ready_event), + kernel_buffer_threshold_(kernel_buffer_threshold), + pending_pinned_byte_size_(0), pending_pinned_offset_(0), + pending_copy_kernel_buffer_byte_size_(0), + pending_copy_kernel_buffer_offset_(0), + pending_copy_kernel_input_buffer_counts_(0), async_task_count_(0), + host_policy_cstr_(host_policy_name), copy_on_stream_(copy_on_stream), + coalesce_request_input_(coalesce_request_input) + { + } + + ~BackendInputCollector() = default; + + // Process all requests for a named input tensor and return the + // concatenated values of those requests in a single contiguous + // buffer. This overload of the function can avoid data copy if the + // tensor values are already contiguous and the caller doesn't + // provide a destination 'buffer'. + // + // 'buffer' is used to determine whether the input should be placed at the + // 'buffer' provided by the caller. If 'buffer' == nullptr, the returned + // buffer will be managed by the BackendInputCollector object and + // has the same lifecycle as the BackendInputCollector object. + // 'buffer_byte_size' is the byte size of 'buffer' if it is not nullptr. + // 'allowed_input_types' is the ordered list of the memory type and id pairs + // that the returned buffer can be. It must only contain the memory type + // and id of 'buffer' if 'buffer' is not nullptr. + // 'dst_buffer' returns the contiguous buffer of the input tensor. + // 'dst_buffer_byte_size' the byte size of 'dst_buffer'. + // 'dst_memory_type' returns the memory type of 'dst_buffer'. + // 'dst_memory_type_id' returns the memory type id of 'dst_buffer'. + TRITONSERVER_Error* ProcessTensor( + const char* input_name, char* buffer, const size_t buffer_byte_size, + const std::vector>& + allowed_input_types, + const char** dst_buffer, size_t* dst_buffer_byte_size, + TRITONSERVER_MemoryType* dst_memory_type, int64_t* dst_memory_type_id); + + // Process all requests for a named input tensor and return the + // concatenated values of those requests in a single contiguous + // 'buffer'. + // + // 'buffer' The buffer to hold the concatenates tensor value. Must + // be large enough to hold all tensor value. + // 'buffer_byte_size' is the byte size of 'buffer'. + // 'dst_memory_type' The memory type of 'buffer'. + // 'dst_memory_type_id' The memory type id of 'buffer'. + void ProcessTensor( + const char* input_name, char* buffer, const size_t buffer_byte_size, + const TRITONSERVER_MemoryType memory_type, const int64_t memory_type_id); + + // Process the batch input and return its shape. Returning error indicates + // that the batch input can't be formed properly and the caller should abort + // the whole batch. 
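A hedged sketch of the typical ProcessTensor flow for a CPU-only backend; the GatherInput wrapper and the total_byte_size argument (the summed size of the INPUT tensor across the batch, computed by the caller) are assumptions for the example:

#include <vector>

#include "triton/backend/backend_input_collector.h"

using namespace triton::backend;

// Sketch: gather the "INPUT" tensor of every request in the batch into one
// contiguous CPU buffer. 'responses' holds one response per request.
void
GatherInput(
    TRITONBACKEND_Request** requests, const uint32_t request_count,
    std::vector<TRITONBACKEND_Response*>* responses,
    TRITONBACKEND_MemoryManager* memory_manager, const size_t total_byte_size,
    std::vector<char>* input_buffer)
{
  BackendInputCollector collector(
      requests, request_count, responses, memory_manager,
      false /* pinned_enabled */, nullptr /* stream */);

  input_buffer->resize(total_byte_size);
  collector.ProcessTensor(
      "INPUT", input_buffer->data(), input_buffer->size(),
      TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */);

  // For CPU-only copies Finalize() returns false, so no stream sync is needed.
  collector.Finalize();
}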
+ TRITONSERVER_Error* BatchInputShape( + const BatchInput& batch_input, std::vector* shape); + + // Process the batch input and derive its value into 'buffer'. Returning + // error indicates that the batch input can't be formed properly and + // the caller should abort the whole batch. + // 'buffer' is used to determine whether the input should be placed at the + // 'buffer' provided by the caller. If 'buffer' == nullptr, the returned + // buffer will be managed by the BackendInputCollector object and + // has the same lifecycle as the BackendInputCollector object. + // 'buffer_byte_size' is the byte size of 'buffer' if it is not nullptr. + // 'allowed_input_types' is the ordered list of the memory type and id pairs + // that the returned buffer can be. It must only contain the memory type + // and id of 'buffer' if it is not nullptr. + // 'dst_buffer' returns the contiguous buffer of the input tensor. + // 'dst_memory_type' returns the memory type of 'dst_buffer'. + // 'dst_memory_type_id' returns the memory type id of 'dst_buffer'. + TRITONSERVER_Error* ProcessBatchInput( + const BatchInput& batch_input, char* buffer, + const size_t buffer_byte_size, + const std::vector>& + allowed_input_types, + const char** dst_buffer, size_t* dst_buffer_byte_size, + TRITONSERVER_MemoryType* dst_memory_type, int64_t* dst_memory_type_id); + + // Finalize processing of all requests for all input tensors. Return + // true if cudaMemcpyAsync is called, and the caller should call + // cudaStreamSynchronize (or cudaEventSynchronize on 'event') before + // using the data. + bool Finalize(); + + private: + struct ContiguousBuffer { + ContiguousBuffer() : start_request_idx_(0), end_request_idx_(0) {} + MemoryDesc memory_desc_; + size_t start_request_idx_; + size_t end_request_idx_; + }; + + class InputIterator { + public: + InputIterator( + TRITONBACKEND_Request** requests, const uint32_t request_count, + std::vector* responses, const char* input_name, + const char* host_policy_name, const bool coalesce_request_input); + + // Return false if iterator reaches the end of inputs, 'input' is not set. + bool GetNextContiguousInput(ContiguousBuffer* input); + + private: + TRITONBACKEND_Request** requests_; + const uint32_t request_count_; + std::vector* responses_; + const char* input_name_; + const char* host_policy_; + const bool coalesce_request_input_; + + TRITONBACKEND_Input* curr_input_; + size_t curr_request_idx_; + size_t curr_buffer_idx_; + uint32_t curr_buffer_cnt_; + bool reach_end_; + }; + + // Return whether the entire input is in a contiguous buffer. If returns true, + // the properties of the contiguous input buffer will also be returned. + // Otherwise, only 'buffer_byte_size' will be set and return the total byte + // size of the input. 
+ bool GetInputBufferIfContiguous( + const char* input_name, const char** buffer, size_t* buffer_byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id); + bool FlushPendingPinned( + char* tensor_buffer, const size_t tensor_buffer_byte_size, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id); + bool FlushPendingCopyKernel( + char* tensor_buffer, const size_t tensor_buffer_byte_size, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id); + TRITONSERVER_Error* LaunchCopyKernel( + char* tensor_buffer, const size_t tensor_buffer_byte_size, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id); + bool SetInputTensor( + const char* input_name, const ContiguousBuffer& input, + char* tensor_buffer, const size_t tensor_buffer_byte_size, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id, const size_t tensor_buffer_offset, + const TRITONSERVER_MemoryType use_pinned_memory_type, + const bool use_kernel, const bool wait_buffer); + template + TRITONSERVER_Error* SetElementCount( + const std::string& source_input, char* buffer, + const size_t buffer_byte_size); + template + TRITONSERVER_Error* SetAccumulatedElementCount( + const std::string& source_input, char* buffer, + const size_t buffer_byte_size); + template + TRITONSERVER_Error* SetBatchItemShape( + const std::string& source_input, char* buffer, + const size_t buffer_byte_size); + + bool need_sync_; + TRITONBACKEND_Request** requests_; + const uint32_t request_count_; + std::vector* responses_; + TRITONBACKEND_MemoryManager* memory_manager_; + const bool pinned_enabled_; + const bool use_async_cpu_copy_; + cudaStream_t stream_; + cudaEvent_t event_; + cudaEvent_t buffer_ready_event_; + const size_t kernel_buffer_threshold_; + + size_t pending_pinned_byte_size_; + size_t pending_pinned_offset_; + std::list pending_pinned_input_buffers_; + + // managed memories that need to live over the lifetime of this + // BackendInputCollector object. + std::list> in_use_memories_; + + size_t pending_copy_kernel_buffer_byte_size_; + size_t pending_copy_kernel_buffer_offset_; + size_t pending_copy_kernel_input_buffer_counts_; + std::list pending_copy_kernel_input_buffers_; + std::vector>> input_ptr_buffer_host_; + std::vector>> byte_size_buffer_host_; + std::vector>> + byte_size_offset_buffer_host_; + + // Pinned memory buffers and the corresponding request_inputs where + // the final copy to the tensor is deferred until Finalize() after + // waiting for all in-flight copies. 
+ struct DeferredPinned { + DeferredPinned( + char* pinned_memory, const size_t pinned_memory_size, + char* tensor_buffer, const size_t tensor_buffer_offset, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_id, + std::list&& request_buffers, + std::vector* responses) + : finalized_(false), pinned_memory_(pinned_memory), + pinned_memory_size_(pinned_memory_size), + tensor_buffer_(tensor_buffer), + tensor_buffer_offset_(tensor_buffer_offset), + tensor_memory_type_(tensor_memory_type), + tensor_memory_id_(tensor_memory_id), + requests_(std::move(request_buffers)), responses_(responses) + { + } + + bool Finalize(cudaStream_t stream); + bool finalized_; + // Holding reference to the pinned memory buffer, which is managed + // by BackendInputCollector as 'pinned_memory' + char* pinned_memory_; + const size_t pinned_memory_size_; + char* tensor_buffer_; + const size_t tensor_buffer_offset_; + const TRITONSERVER_MemoryType tensor_memory_type_; + const int64_t tensor_memory_id_; + std::list requests_; + std::vector* responses_; + }; + + std::list deferred_pinned_; + // FIXME use future to maintain an issue-order queue to drop task count + triton::common::SyncQueue completion_queue_; + size_t async_task_count_; + + const char* host_policy_cstr_; + const bool copy_on_stream_; + const bool coalesce_request_input_; +}; + +}} // namespace triton::backend diff --git a/3rdparty/backend-r22.12/include/triton/backend/backend_memory.h b/3rdparty/backend-r22.12/include/triton/backend/backend_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..819ca3743a5b2f929e4b8f429e622c1e86a47232 --- /dev/null +++ b/3rdparty/backend-r22.12/include/triton/backend/backend_memory.h @@ -0,0 +1,138 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+#pragma once + +#include +#include +#include "triton/core/tritonbackend.h" +#include "triton/core/tritonserver.h" + +namespace triton { namespace backend { + +// Colletion of common properties that describes a buffer in Triton +struct MemoryDesc { + MemoryDesc() + : buffer_(nullptr), byte_size_(0), memory_type_(TRITONSERVER_MEMORY_CPU), + memory_type_id_(0) + { + } + MemoryDesc( + const char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) + : buffer_(buffer), byte_size_(byte_size), memory_type_(memory_type), + memory_type_id_(memory_type_id) + { + } + const char* buffer_; + size_t byte_size_; + TRITONSERVER_MemoryType memory_type_; + int64_t memory_type_id_; +}; + +// +// BackendMemory +// +// Utility class for allocating and deallocating memory using both +// TRITONBACKEND_MemoryManager and direct GPU and CPU malloc/free. +// +class BackendMemory { + public: + enum class AllocationType { CPU, CPU_PINNED, GPU, CPU_PINNED_POOL, GPU_POOL }; + + // Allocate a contiguous block of 'alloc_type' memory. 'mem' + // returns the pointer to the allocated memory. + // + // CPU, CPU_PINNED_POOL and GPU_POOL are allocated using + // TRITONBACKEND_MemoryManagerAllocate. Note that CPU_PINNED and GPU + // allocations can be much slower than the POOL variants. + // + // Two error codes have specific interpretations for this function: + // + // TRITONSERVER_ERROR_UNSUPPORTED: Indicates that function is + // incapable of allocating the requested memory type and memory + // type ID. Requests for the memory type and ID will always fail + // no matter 'byte_size' of the request. + // + // TRITONSERVER_ERROR_UNAVAILABLE: Indicates that function can + // allocate the memory type and ID but that currently it cannot + // allocate a contiguous block of memory of the requested + // 'byte_size'. + static TRITONSERVER_Error* Create( + TRITONBACKEND_MemoryManager* manager, const AllocationType alloc_type, + const int64_t memory_type_id, const size_t byte_size, + BackendMemory** mem); + + // Allocate a contiguous block of memory by attempting the + // allocation using 'alloc_types' in order until one is successful. + // See BackendMemory::Create() above for details. + static TRITONSERVER_Error* Create( + TRITONBACKEND_MemoryManager* manager, + const std::vector& alloc_types, + const int64_t memory_type_id, const size_t byte_size, + BackendMemory** mem); + + // Creates a BackendMemory object from a pre-allocated buffer. The buffer + // is not owned by the object created with this function. Hence, for + // proper operation, the lifetime of the buffer should atleast extend till + // the corresponding BackendMemory. 
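A hedged sketch of the fallback-allocation overload of BackendMemory::Create documented above, preferring the pinned pool and falling back to plain CPU memory; the AllocScratch wrapper is illustrative:

#include <memory>
#include <vector>

#include "triton/backend/backend_common.h"
#include "triton/backend/backend_memory.h"

using namespace triton::backend;

TRITONSERVER_Error*
AllocScratch(
    TRITONBACKEND_MemoryManager* manager, const size_t byte_size,
    std::unique_ptr<BackendMemory>* scratch)
{
  // Try the pinned pool first; if it cannot satisfy the request, fall back
  // to a plain CPU allocation.
  BackendMemory* mem = nullptr;
  RETURN_IF_ERROR(BackendMemory::Create(
      manager,
      {BackendMemory::AllocationType::CPU_PINNED_POOL,
       BackendMemory::AllocationType::CPU},
      0 /* memory_type_id */, byte_size, &mem));
  scratch->reset(mem);  // BackendMemory releases the buffer in its destructor
  return nullptr;  // success; (*scratch)->MemoryPtr() is the usable buffer
}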
+ static TRITONSERVER_Error* Create( + TRITONBACKEND_MemoryManager* manager, const AllocationType alloc_type, + const int64_t memory_type_id, void* buffer, const size_t byte_size, + BackendMemory** mem); + + ~BackendMemory(); + + AllocationType AllocType() const { return alloctype_; } + int64_t MemoryTypeId() const { return memtype_id_; } + char* MemoryPtr() { return buffer_; } + size_t ByteSize() const { return byte_size_; } + TRITONSERVER_MemoryType MemoryType() const + { + return AllocTypeToMemoryType(alloctype_); + } + + static TRITONSERVER_MemoryType AllocTypeToMemoryType(const AllocationType a); + static const char* AllocTypeString(const AllocationType a); + + private: + BackendMemory( + TRITONBACKEND_MemoryManager* manager, const AllocationType alloctype, + const int64_t memtype_id, char* buffer, const size_t byte_size, + const bool owns_buffer = true) + : manager_(manager), alloctype_(alloctype), memtype_id_(memtype_id), + buffer_(buffer), byte_size_(byte_size), owns_buffer_(owns_buffer) + { + } + + TRITONBACKEND_MemoryManager* manager_; + AllocationType alloctype_; + int64_t memtype_id_; + char* buffer_; + size_t byte_size_; + bool owns_buffer_; +}; + +}} // namespace triton::backend diff --git a/3rdparty/backend-r22.12/include/triton/backend/backend_model.h b/3rdparty/backend-r22.12/include/triton/backend/backend_model.h new file mode 100644 index 0000000000000000000000000000000000000000..3179c6e8e656c8c9f618d297fdbd6a61b8c1fac0 --- /dev/null +++ b/3rdparty/backend-r22.12/include/triton/backend/backend_model.h @@ -0,0 +1,146 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include "triton/backend/backend_common.h" +#include "triton/core/tritonbackend.h" +#include "triton/core/tritonserver.h" + +namespace triton { namespace backend { + +// +// BackendModel +// +// Common functionality for a backend model. This class is provided as +// a convenience; backends are not required to use this class. 
+//
+class BackendModel {
+ public:
+  BackendModel(
+      TRITONBACKEND_Model* triton_model, const bool allow_optional = false);
+  virtual ~BackendModel() = default;
+
+  // Get the handle to the TRITONBACKEND server hosting this model.
+  TRITONSERVER_Server* TritonServer() { return triton_server_; }
+
+  // Get the handle to the memory manager for this model.
+  TRITONBACKEND_MemoryManager* TritonMemoryManager()
+  {
+    return triton_memory_manager_;
+  }
+
+  // Get the handle to the TRITONBACKEND model.
+  TRITONBACKEND_Model* TritonModel() { return triton_model_; }
+
+  // Get the name and version of the model.
+  const std::string& Name() const { return name_; }
+  uint64_t Version() const { return version_; }
+  const std::string& RepositoryPath() const { return repository_path_; }
+
+  // The model configuration.
+  common::TritonJson::Value& ModelConfig() { return model_config_; }
+
+  // Sets the updated model configuration to the core.
+  TRITONSERVER_Error* SetModelConfig();
+
+  // Parses information out of the model configuration.
+  TRITONSERVER_Error* ParseModelConfig();
+
+  // Maximum batch size supported by the model. A value of 0
+  // indicates that the model does not support batching.
+  int MaxBatchSize() const { return max_batch_size_; }
+
+  // Set the max batch size for the model. When a backend
+  // auto-completes a configuration it may set or change the maximum
+  // batch size.
+  void SetMaxBatchSize(const int b) { max_batch_size_ = b; }
+
+  // Does this model support batching in the first dimension?
+  TRITONSERVER_Error* SupportsFirstDimBatching(bool* supports);
+
+  // Use an indirect pinned memory buffer when copying an input or output
+  // tensor to/from the model.
+  bool EnablePinnedInput() const { return enable_pinned_input_; }
+  bool EnablePinnedOutput() const { return enable_pinned_output_; }
+
+  const std::vector<BatchInput>& BatchInputs() const { return batch_inputs_; }
+  const std::vector<BatchOutput>& BatchOutputs() const
+  {
+    return batch_outputs_;
+  }
+  const BatchOutput* FindBatchOutput(const std::string& output_name) const;
+  bool IsInputRagged(const std::string& input_name) const
+  {
+    return (ragged_inputs_.find(input_name) != ragged_inputs_.end());
+  }
+  bool IsInputOptional(const std::string& input_name) const
+  {
+    return (optional_inputs_.find(input_name) != optional_inputs_.end());
+  }
+
+ protected:
+  TRITONSERVER_Server* triton_server_;
+  TRITONBACKEND_MemoryManager* triton_memory_manager_;
+  TRITONBACKEND_Model* triton_model_;
+  std::string name_;
+  uint64_t version_;
+  std::string repository_path_;
+  bool allow_optional_;
+
+  common::TritonJson::Value model_config_;
+  int max_batch_size_;
+  bool enable_pinned_input_;
+  bool enable_pinned_output_;
+  std::vector<BatchInput> batch_inputs_;
+  std::vector<BatchOutput> batch_outputs_;
+  std::map<std::string, const BatchOutput*> batch_output_map_;
+  std::set<std::string> ragged_inputs_;
+  std::set<std::string> optional_inputs_;
+};
+
+//
+// BackendModelException
+//
+// Exception thrown if an error occurs while constructing a
+// BackendModel.
+// +struct BackendModelException { + BackendModelException(TRITONSERVER_Error* err) : err_(err) {} + TRITONSERVER_Error* err_; +}; + +#define THROW_IF_BACKEND_MODEL_ERROR(X) \ + do { \ + TRITONSERVER_Error* tie_err__ = (X); \ + if (tie_err__ != nullptr) { \ + throw triton::backend::BackendModelException(tie_err__); \ + } \ + } while (false) + +}} // namespace triton::backend diff --git a/3rdparty/backend-r22.12/include/triton/backend/backend_model_instance.h b/3rdparty/backend-r22.12/include/triton/backend/backend_model_instance.h new file mode 100644 index 0000000000000000000000000000000000000000..c4deeea09d760c5a3438532469eabcc570fdb004 --- /dev/null +++ b/3rdparty/backend-r22.12/include/triton/backend/backend_model_instance.h @@ -0,0 +1,118 @@ +// Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include "triton/core/tritonbackend.h" + +#ifdef TRITON_ENABLE_GPU +#include +#endif // TRITON_ENABLE_GPU + +namespace triton { namespace backend { + +#ifndef TRITON_ENABLE_GPU +using cudaStream_t = void*; +#endif // !TRITON_ENABLE_GPU + +class BackendModel; + +// +// BackendModelInstance +// +// Common functionality for a backend model instance. This class is +// provided as a convenience; backends are not required to use this +// class. +// +class BackendModelInstance { + public: + BackendModelInstance( + BackendModel* backend_model, + TRITONBACKEND_ModelInstance* triton_model_instance); + virtual ~BackendModelInstance(); + + // Get the name, kind and device ID of the instance. + const std::string& Name() const { return name_; } + TRITONSERVER_InstanceGroupKind Kind() const { return kind_; } + int32_t DeviceId() const { return device_id_; } + + // Get the handle to the TRITONBACKEND model instance. + TRITONBACKEND_ModelInstance* TritonModelInstance() + { + return triton_model_instance_; + } + + // Get the BackendModel representing the model that corresponds to + // this instance. 
+  BackendModel* Model() const { return backend_model_; }
+
+  // The model configuration 'default_model_filename' value, or the
+  // value in the model configuration 'cc_model_filenames' for the GPU
+  // targeted by this instance. If neither is specified in the model
+  // configuration, the empty string is returned.
+  const std::string& ArtifactFilename() const { return artifact_filename_; }
+
+  // Returns the stream associated with this instance that can be used
+  // for GPU<->CPU memory transfers. Returns nullptr if GPU support is
+  // disabled or if this instance is not executing on a GPU.
+  cudaStream_t CudaStream() { return stream_; }
+
+  const std::string& HostPolicyName() const { return host_policy_name_; }
+
+ protected:
+  BackendModel* backend_model_;
+  TRITONBACKEND_ModelInstance* triton_model_instance_;
+
+  std::string name_;
+  TRITONSERVER_InstanceGroupKind kind_;
+  int32_t device_id_;
+
+  std::string artifact_filename_;
+  cudaStream_t stream_;
+
+  std::string host_policy_name_;
+};
+
+//
+// BackendModelInstanceException
+//
+// Exception thrown if an error occurs while constructing a
+// BackendModelInstance.
+//
+struct BackendModelInstanceException {
+  BackendModelInstanceException(TRITONSERVER_Error* err) : err_(err) {}
+  TRITONSERVER_Error* err_;
+};
+
+#define THROW_IF_BACKEND_INSTANCE_ERROR(X)                              \
+  do {                                                                  \
+    TRITONSERVER_Error* tie_err__ = (X);                                \
+    if (tie_err__ != nullptr) {                                         \
+      throw triton::backend::BackendModelInstanceException(tie_err__);  \
+    }                                                                   \
+  } while (false)
+
+}}  // namespace triton::backend
diff --git a/3rdparty/backend-r22.12/include/triton/backend/backend_output_responder.h b/3rdparty/backend-r22.12/include/triton/backend/backend_output_responder.h
new file mode 100644
index 0000000000000000000000000000000000000000..611e103c0a9e4dd6f526e40eb746f6b7a9b6a3d1
--- /dev/null
+++ b/3rdparty/backend-r22.12/include/triton/backend/backend_output_responder.h
@@ -0,0 +1,195 @@
+// Copyright 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//  * Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//  * Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
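Editorial note (not part of the vendored diff): the `BackendModel`/`BackendModelInstance` helpers and the `THROW_IF_BACKEND_*` macros above are typically combined in a backend's initialize entry points. The sketch below is illustrative only; the `ModelState`/`ModelInstanceState` class names are hypothetical, and it assumes the `RETURN_IF_ERROR` macro from `backend_common.h`.

```cpp
// Illustrative sketch -- hypothetical backend code built on the helpers above.
#include "triton/backend/backend_common.h"
#include "triton/backend/backend_model.h"
#include "triton/backend/backend_model_instance.h"

class ModelState : public triton::backend::BackendModel {
 public:
  explicit ModelState(TRITONBACKEND_Model* model)
      : triton::backend::BackendModel(model)
  {
  }
};

class ModelInstanceState : public triton::backend::BackendModelInstance {
 public:
  ModelInstanceState(
      ModelState* model_state, TRITONBACKEND_ModelInstance* instance)
      : triton::backend::BackendModelInstance(model_state, instance)
  {
  }
};

extern "C" TRITONSERVER_Error*
TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance)
{
  // Look up the ModelState that TRITONBACKEND_ModelInitialize stored earlier.
  TRITONBACKEND_Model* model;
  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model));
  void* vstate;
  RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate));
  auto* model_state = reinterpret_cast<ModelState*>(vstate);

  // The BackendModelInstance constructor reports failures by throwing
  // BackendModelInstanceException (see THROW_IF_BACKEND_INSTANCE_ERROR),
  // which is converted back into a TRITONSERVER_Error here.
  try {
    auto* instance_state = new ModelInstanceState(model_state, instance);
    RETURN_IF_ERROR(
        TRITONBACKEND_ModelInstanceSetState(instance, instance_state));
  }
  catch (const triton::backend::BackendModelInstanceException& ex) {
    return ex.err_;
  }

  return nullptr;  // success
}
```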
+#pragma once + +#include +#include +#include +#include "triton/backend/backend_common.h" +#include "triton/common/async_work_queue.h" +#include "triton/core/tritonbackend.h" + +#ifdef TRITON_ENABLE_GPU +#include +#endif // TRITON_ENABLE_GPU + +namespace triton { namespace backend { + +#ifndef TRITON_ENABLE_GPU +using cudaStream_t = void*; +using cudaEvent_t = void*; +#endif // !TRITON_ENABLE_GPU + +// +// BackendOutputResponder +// +class BackendOutputResponder { + public: + // The caller can optionally provide 'event' for internal synchronization + // instead of using 'stream'. + explicit BackendOutputResponder( + TRITONBACKEND_Request** requests, const uint32_t request_count, + std::vector* responses, + TRITONBACKEND_MemoryManager* memory_manager, + const bool first_dim_batching, const bool pinned_enabled, + cudaStream_t stream, cudaEvent_t event = nullptr, + bool copy_on_stream = false) + : need_sync_(false), requests_(requests), request_count_(request_count), + responses_(responses), memory_manager_(memory_manager), + first_dim_batching_(first_dim_batching), + pinned_enabled_(pinned_enabled), + use_async_cpu_copy_(triton::common::AsyncWorkQueue::WorkerCount() > 1), + stream_(stream), event_(event), pending_pinned_byte_size_(0), + copy_on_stream_(copy_on_stream) + { + } + + // Legacy constructor for backwards compatibility. The above + // constructor should be used for all new cases. The responder needs + // to know if the model is batching along the first dimension. With + // this constructor we derive that information from the + // max_batch_size value instead of having it provided directly as in + // the above constructor. + explicit BackendOutputResponder( + TRITONBACKEND_Request** requests, const uint32_t request_count, + std::vector* responses, const int max_batch_size, + TRITONBACKEND_MemoryManager* memory_manager, const bool pinned_enabled, + cudaStream_t stream, cudaEvent_t event = nullptr, + bool copy_on_stream = false) + : need_sync_(false), requests_(requests), request_count_(request_count), + responses_(responses), memory_manager_(memory_manager), + first_dim_batching_(max_batch_size >= 1), + pinned_enabled_(pinned_enabled), + use_async_cpu_copy_(triton::common::AsyncWorkQueue::WorkerCount() > 1), + stream_(stream), event_(event), pending_pinned_byte_size_(0), + copy_on_stream_(copy_on_stream) + { + } + + ~BackendOutputResponder(); + + // Process all responses for a named output tensor. + // 'batchn_shape' may be modified by the call. + void ProcessTensor( + const std::string& name, const TRITONSERVER_DataType datatype, + std::vector& batchn_shape, const char* buffer, + const TRITONSERVER_MemoryType memory_type, const int64_t memory_type_id); + + // Process all responses for a named state tensor. Returns a vector of + // TRITONBACKEND_State objects that the backend can use to update the state. + // If TRITONBACKEND_StateUpdate is not called on the vector elements, the + // state will not be updated. + // 'batchn_shape' may be modified by the call. + std::vector ProcessStateTensor( + const std::string& name, const TRITONSERVER_DataType datatype, + std::vector& batchn_shape, const char* buffer, + const TRITONSERVER_MemoryType memory_type, const int64_t memory_type_id); + + // Process all responses for a batch output and derive its value from + // 'buffer'. 
+ void ProcessBatchOutput( + const std::string& name, const BatchOutput& batch_output, + const char* buffer, const TRITONSERVER_MemoryType memory_type, + const int64_t memory_type_id); + + // Finalize processing of all responses for all output + // tensors. Return true if cudaMemcpyAsync is called, and the caller + // should call cudaStreamSynchronize (or cudaEventSynchronize on 'event') + // before using the data. + bool Finalize(); + + private: + bool FlushPendingPinned( + const char* tensor_buffer, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id); + bool SetFixedSizeBuffer( + TRITONBACKEND_Response** response, void* response_state_or_output, + const std::string& output_name, const size_t tensor_byte_size, + const size_t tensor_offset, const char* tensor_buffer, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id, + const TRITONSERVER_MemoryType use_pinned_memory_type, bool state); + + struct OutputData { + OutputData( + const std::string& name, void* buffer, const size_t buffer_byte_size, + const TRITONSERVER_MemoryType memory_type, const int64_t memory_type_id) + : name_(name), buffer_(buffer), buffer_byte_size_(buffer_byte_size), + memory_type_(memory_type), memory_type_id_(memory_type_id) + { + } + const std::string name_; + void* buffer_; + const size_t buffer_byte_size_; + const TRITONSERVER_MemoryType memory_type_; + const int64_t memory_type_id_; + }; + + bool need_sync_; + TRITONBACKEND_Request** requests_; + const uint32_t request_count_; + std::vector* responses_; + TRITONBACKEND_MemoryManager* memory_manager_; + const bool first_dim_batching_; + const bool pinned_enabled_; + const bool use_async_cpu_copy_; + cudaStream_t stream_; + cudaEvent_t event_; + + using ResponsesList = + std::list>; + + size_t pending_pinned_byte_size_; + size_t pending_pinned_offset_; + ResponsesList pending_pinned_outputs_; + const bool copy_on_stream_; + + // Pinned memories that need to live over the lifetime of this + // BackendOutputResponder object. + std::list pinned_memories_; + + // Pinned memory buffers and the corresponding response outputs + // where the final copy to the response is deferred until Finalize() + // after waiting for all in-flight copies. + struct DeferredPinned { + DeferredPinned( + char* pinned_memory, const size_t pinned_memory_size, + ResponsesList&& responses) + : pinned_memory_(pinned_memory), + pinned_memory_size_(pinned_memory_size), + responses_(std::move(responses)) + { + } + char* pinned_memory_; + const size_t pinned_memory_size_; + ResponsesList responses_; + }; + + std::list deferred_pinned_; +}; + +}} // namespace triton::backend diff --git a/3rdparty/backend-r22.12/src/backend_common.cc b/3rdparty/backend-r22.12/src/backend_common.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f7a660b3d55a95444734857ed500ea1ebb1dbfc --- /dev/null +++ b/3rdparty/backend-r22.12/src/backend_common.cc @@ -0,0 +1,1374 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "triton/backend/backend_common.h" + +#ifdef _WIN32 +// suppress the min and max definitions in Windef.h. +#define NOMINMAX +#include + +// _CRT_INTERNAL_NONSTDC_NAMES 1 before including Microsoft provided C Runtime +// library to expose declarations without "_" prefix to match POSIX style. +#define _CRT_INTERNAL_NONSTDC_NAMES 1 +#include +#include +#else +#include +#include +#endif +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +// in Windows doesn't define S_ISDIR macro +#if !defined(S_ISDIR) && defined(S_IFMT) && defined(S_IFDIR) +#define S_ISDIR(m) (((m)&S_IFMT) == S_IFDIR) +#endif +#define F_OK 0 +#endif + +namespace triton { namespace backend { + +#ifdef TRITON_ENABLE_GPU +void CUDART_CB +MemcpyHost(void* args) +{ + auto* copy_params = reinterpret_cast(args); + memcpy(copy_params->dst_, copy_params->src_, copy_params->byte_size_); + delete copy_params; +} +#endif // TRITON_ENABLE_GPU + +TRITONSERVER_MemoryType +GetUsePinnedMemoryType(TRITONSERVER_MemoryType ref_buffer_type) +{ + // The following matrix is used for both input and output. + // src \ dest | non-pinned | pinned | device + // non-pinned | memcpy | memcpy | buffer needed + // pinned | memcpy | memcpy | cudaMemcpy + // device | buffer needed | cudaMemcpy | cudaMemcpy + if (ref_buffer_type == TRITONSERVER_MEMORY_CPU_PINNED) { + return TRITONSERVER_MEMORY_CPU_PINNED; + } + + return (ref_buffer_type == TRITONSERVER_MEMORY_CPU) ? 
TRITONSERVER_MEMORY_GPU + : TRITONSERVER_MEMORY_CPU; +} + +TRITONSERVER_Error_Code +StatusCodeToTritonCode(triton::common::Error::Code error_code) +{ + switch (error_code) { + case triton::common::Error::Code::UNKNOWN: + return TRITONSERVER_ERROR_UNKNOWN; + case triton::common::Error::Code::INTERNAL: + return TRITONSERVER_ERROR_INTERNAL; + case triton::common::Error::Code::NOT_FOUND: + return TRITONSERVER_ERROR_NOT_FOUND; + case triton::common::Error::Code::INVALID_ARG: + return TRITONSERVER_ERROR_INVALID_ARG; + case triton::common::Error::Code::UNAVAILABLE: + return TRITONSERVER_ERROR_UNAVAILABLE; + case triton::common::Error::Code::UNSUPPORTED: + return TRITONSERVER_ERROR_UNSUPPORTED; + case triton::common::Error::Code::ALREADY_EXISTS: + return TRITONSERVER_ERROR_ALREADY_EXISTS; + + default: + break; + } + + return TRITONSERVER_ERROR_UNKNOWN; +} + +TRITONSERVER_Error* +CommonErrorToTritonError(triton::common::Error error) +{ + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(error.ErrorCode()), error.Message().c_str()); +} + +TRITONSERVER_Error* +ParseShape( + common::TritonJson::Value& io, const std::string& name, + std::vector* shape) +{ + common::TritonJson::Value shape_array; + RETURN_IF_ERROR(io.MemberAsArray(name.c_str(), &shape_array)); + for (size_t i = 0; i < shape_array.ArraySize(); ++i) { + int64_t d = 0; + RETURN_IF_ERROR(shape_array.IndexAsInt(i, &d)); + shape->push_back(d); + } + + return nullptr; // success +} + +std::string +ShapeToString(const int64_t* dims, const size_t dims_count) +{ + bool first = true; + + std::string str("["); + for (size_t i = 0; i < dims_count; ++i) { + const int64_t dim = dims[i]; + if (!first) { + str += ","; + } + str += std::to_string(dim); + first = false; + } + + str += "]"; + return str; +} + +std::string +ShapeToString(const std::vector& shape) +{ + return ShapeToString(shape.data(), shape.size()); +} + +int64_t +GetElementCount(const int64_t* dims, const size_t dims_count) +{ + bool first = true; + int64_t cnt = 0; + for (size_t i = 0; i < dims_count; i++) { + if (dims[i] == WILDCARD_DIM) { + return -1; + } + + if (first) { + cnt = dims[i]; + first = false; + } else { + cnt *= dims[i]; + } + } + + return cnt; +} + +int64_t +GetElementCount(const std::vector& shape) +{ + return GetElementCount(shape.data(), shape.size()); +} + +int64_t +GetByteSize( + const TRITONSERVER_DataType& dtype, const std::vector& dims) +{ + size_t dt_size = TRITONSERVER_DataTypeByteSize(dtype); + if (dt_size == 0) { + return -1; + } + + int64_t cnt = GetElementCount(dims); + if (cnt == -1) { + return -1; + } + + return cnt * dt_size; +} + +TRITONSERVER_Error* +ReadInputTensor( + TRITONBACKEND_Request* request, const std::string& input_name, char* buffer, + size_t* buffer_byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id, cudaStream_t cuda_stream, bool* cuda_used, + const char* host_policy_name, const bool copy_on_stream) +{ + TRITONBACKEND_Input* input; + RETURN_IF_ERROR( + TRITONBACKEND_RequestInput(request, input_name.c_str(), &input)); + + uint64_t input_byte_size; + uint32_t input_buffer_count; + RETURN_IF_ERROR(TRITONBACKEND_InputPropertiesForHostPolicy( + input, host_policy_name, nullptr, nullptr, nullptr, nullptr, + &input_byte_size, &input_buffer_count)); + RETURN_ERROR_IF_FALSE( + input_byte_size <= *buffer_byte_size, TRITONSERVER_ERROR_INVALID_ARG, + std::string( + GetRequestId(request) + "buffer too small for input tensor '" + + input_name + "', " + std::to_string(*buffer_byte_size) + " < " + + 
std::to_string(input_byte_size))); + + size_t output_buffer_offset = 0; + for (uint32_t b = 0; b < input_buffer_count; ++b) { + const void* input_buffer = nullptr; + uint64_t input_buffer_byte_size = 0; + TRITONSERVER_MemoryType input_memory_type = TRITONSERVER_MEMORY_CPU; + int64_t input_memory_type_id = 0; + + RETURN_IF_ERROR(TRITONBACKEND_InputBufferForHostPolicy( + input, host_policy_name, b, &input_buffer, &input_buffer_byte_size, + &input_memory_type, &input_memory_type_id)); + + RETURN_IF_ERROR(CopyBuffer( + "Failed to copy buffer", input_memory_type, input_memory_type_id, + memory_type, memory_type_id, input_buffer_byte_size, input_buffer, + buffer + output_buffer_offset, cuda_stream, cuda_used, copy_on_stream)); + + output_buffer_offset += input_buffer_byte_size; + } + + *buffer_byte_size = input_byte_size; + + return nullptr; // success +} + +TRITONSERVER_Error* +ReadInputTensor( + TRITONBACKEND_Request* request, const std::string& input_name, char* buffer, + size_t* buffer_byte_size, const char* host_policy_name) +{ + bool cuda_used; + return ReadInputTensor( + request, input_name, buffer, buffer_byte_size, + TRITONSERVER_MEMORY_CPU /* memory_type */, 0 /* memory_type_id */, + 0 /* cuda_stream */, &cuda_used); +} + +TRITONSERVER_Error* +CheckAllowedModelInput( + common::TritonJson::Value& io, const std::set& allowed) +{ + std::string io_name; + RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); + if (allowed.find(io_name) == allowed.end()) { + std::string astr; + for (const auto& a : allowed) { + if (!astr.empty()) { + astr.append(", "); + } + astr.append(a); + } + + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "unexpected inference input '" + io_name + + "', allowed inputs are: " + astr) + .c_str()); + } + return nullptr; // success +} + +TRITONSERVER_Error* +CheckAllowedModelOutput( + common::TritonJson::Value& io, const std::set& allowed) +{ + std::string io_name; + RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); + if (allowed.find(io_name) == allowed.end()) { + std::string astr; + for (const auto& a : allowed) { + if (!astr.empty()) { + astr.append(", "); + } + astr.append(a); + } + + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "unexpected inference output '" + io_name + + "', allowed outputs are: " + astr) + .c_str()); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +GetBooleanSequenceControlProperties( + common::TritonJson::Value& batcher, const std::string& model_name, + const std::string& control_kind, const bool required, + std::string* tensor_name, std::string* tensor_datatype, + float* fp32_false_value, float* fp32_true_value, int32_t* int32_false_value, + int32_t* int32_true_value, bool* bool_false_value, bool* bool_true_value) +{ + // Make sure same tensor is not configured for multiple controls + std::set seen_tensors; + + // Make sure the control kind is not mentioned multiple times. 
+ bool seen_control = false; + + common::TritonJson::Value control_inputs; + if (batcher.Find("control_input", &control_inputs)) { + for (size_t ci_idx = 0; ci_idx < control_inputs.ArraySize(); ci_idx++) { + common::TritonJson::Value control_input; + RETURN_IF_ERROR(control_inputs.IndexAsObject(ci_idx, &control_input)); + std::string input_name; + RETURN_IF_ERROR(control_input.MemberAsString("name", &input_name)); + if (input_name.empty()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching control tensor must have a name for ") + + model_name) + .c_str()); + } + + if (seen_tensors.find(input_name) != seen_tensors.end()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("sequence batching control tensor '") + input_name + + "' is specified for multiple control kinds for " + model_name) + .c_str()); + } + + seen_tensors.insert(input_name); + common::TritonJson::Value controls; + if (control_input.Find("control", &controls)) { + for (size_t c_idx = 0; c_idx < controls.ArraySize(); c_idx++) { + common::TritonJson::Value c; + RETURN_IF_ERROR(controls.IndexAsObject(c_idx, &c)); + std::string kind_str; + RETURN_IF_ERROR(c.MemberAsString("kind", &kind_str)); + if (kind_str == control_kind) { + if (seen_control) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching specifies multiple " + control_kind + + " tensors for " + model_name) + .c_str())); + } + + *tensor_name = input_name; + seen_control = true; + + common::TritonJson::Value int32_false_true, fp32_false_true, + bool_false_true; + bool found_int32 = + (c.Find("int32_false_true", &int32_false_true) && + (int32_false_true.ArraySize() > 0)); + bool found_fp32 = + (c.Find("fp32_false_true", &fp32_false_true) && + (fp32_false_true.ArraySize() > 0)); + bool found_bool = + (c.Find("bool_false_true", &bool_false_true) && + (bool_false_true.ArraySize() > 0)); + + // Make sure only one of int, float, or bool type is specified. 
+ if (!(found_int32 || found_fp32 || found_bool)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching must specify either " + "'int32_false_true', 'fp32_false_true' or " + "'bool_false_true' for " + + control_kind + " for " + model_name)) + .c_str()); + } else if ( + (found_fp32 && found_int32) || (found_fp32 && found_bool) || + (found_int32 && found_bool)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching specifies more than one from " + "'int32_false_true', 'fp32_false_true' and " + "'bool_false_true' for " + + control_kind + " for " + model_name)) + .c_str()); + } + + if (found_int32) { + if (int32_false_true.ArraySize() != 2) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching control 'int32_false_true' must " + "have " + "exactly 2 entries for " + + control_kind + " for " + model_name)) + .c_str()); + } + if (tensor_datatype != nullptr) { + *tensor_datatype = "TYPE_INT32"; + } + if (int32_false_value != nullptr) { + int64_t value; + RETURN_IF_ERROR(int32_false_true.IndexAsInt(0, &value)); + *int32_false_value = value; + } + if (int32_true_value != nullptr) { + int64_t value; + RETURN_IF_ERROR(int32_false_true.IndexAsInt(1, &value)); + *int32_true_value = value; + } + } else if (found_fp32) { + if (fp32_false_true.ArraySize() != 2) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching control 'fp32_false_true' must " + "have exactly " + "2 entries for " + + control_kind + " for " + model_name)) + .c_str()); + } + if (tensor_datatype != nullptr) { + *tensor_datatype = "TYPE_FP32"; + } + if (fp32_false_value != nullptr) { + double value = 0.0; + RETURN_IF_ERROR(fp32_false_true.IndexAsDouble(0, &value)); + *fp32_false_value = value; + } + if (fp32_true_value != nullptr) { + double value = 0.0; + RETURN_IF_ERROR(fp32_false_true.IndexAsDouble(1, &value)); + *fp32_true_value = value; + } + } else { + if (bool_false_true.ArraySize() != 2) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching control 'bool_false_true' must " + "have exactly " + "2 entries for " + + control_kind + " for " + model_name)) + .c_str()); + } + if (tensor_datatype != nullptr) { + *tensor_datatype = "TYPE_BOOL"; + } + if (bool_false_value != nullptr) { + bool value; + RETURN_IF_ERROR(bool_false_true.IndexAsBool(0, &value)); + *bool_false_value = value; + } + if (bool_true_value != nullptr) { + bool value; + RETURN_IF_ERROR(bool_false_true.IndexAsBool(1, &value)); + *bool_true_value = value; + } + } + } + } + } + } + } + + if (!seen_control) { + if (required) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching control tensor must specify a " + + control_kind + " value for " + model_name)) + .c_str()); + } + + tensor_name->clear(); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +GetTypedSequenceControlProperties( + common::TritonJson::Value& batcher, const std::string& model_name, + const std::string& control_kind, const bool required, + std::string* tensor_name, std::string* tensor_datatype) +{ + // Make sure same tensor is not configured for multiple controls + std::set seen_tensors; + + // Make sure the control kind is not mentioned multiple times. 
+ bool seen_control = false; + + common::TritonJson::Value control_inputs; + if (batcher.Find("control_input", &control_inputs)) { + for (size_t ci_idx = 0; ci_idx < control_inputs.ArraySize(); ci_idx++) { + common::TritonJson::Value control_input; + RETURN_IF_ERROR(control_inputs.IndexAsObject(ci_idx, &control_input)); + std::string input_name; + RETURN_IF_ERROR(control_input.MemberAsString("name", &input_name)); + if (input_name.empty()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching control tensor must have a name for ") + + model_name) + .c_str()); + } + if (seen_tensors.find(input_name) != seen_tensors.end()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("sequence batching control tensor '") + input_name + + "' is specified for multiple control kinds for " + model_name) + .c_str()); + } + + seen_tensors.insert(input_name); + common::TritonJson::Value controls; + if (control_input.Find("control", &controls)) { + for (size_t c_idx = 0; c_idx < controls.ArraySize(); c_idx++) { + common::TritonJson::Value c; + RETURN_IF_ERROR(controls.IndexAsObject(c_idx, &c)); + std::string kind_str; + RETURN_IF_ERROR(c.MemberAsString("kind", &kind_str)); + if (kind_str == control_kind) { + if (seen_control) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching specifies multiple " + control_kind + + " tensors for " + model_name) + .c_str())); + } + + *tensor_name = input_name; + if (tensor_datatype != nullptr) { + RETURN_IF_ERROR(c.MemberAsString("data_type", tensor_datatype)); + } + + seen_control = true; + + common::TritonJson::Value int32_false_true, fp32_false_true, + bool_false_true; + bool found_int32 = + (c.Find("int32_false_true", &int32_false_true) && + (int32_false_true.ArraySize() > 0)); + bool found_fp32 = + (c.Find("fp32_false_true", &fp32_false_true) && + (fp32_false_true.ArraySize() > 0)); + bool found_bool = + (c.Find("bool_false_true", &bool_false_true) && + (bool_false_true.ArraySize() > 0)); + if (found_fp32 || found_int32 || found_bool) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching must not specify either " + "'int32_false_true', 'fp32_false_true' or " + "'bool_false_true' for " + + control_kind + " for " + model_name)) + .c_str()); + } + } + } + } + } + } + + if (!seen_control) { + if (required) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "sequence batching control tensor must specify a " + + control_kind + " value for " + model_name)) + .c_str()); + } + + tensor_name->clear(); + } + + return nullptr; // success +} + +void +RequestsRespondWithError( + TRITONBACKEND_Request** requests, const uint32_t request_count, + TRITONSERVER_Error* response_err, const bool release_request) +{ + for (size_t i = 0; i < request_count; i++) { + TRITONBACKEND_Response* response; + auto err = TRITONBACKEND_ResponseNew(&response, requests[i]); + if (err != nullptr) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + (GetRequestId(requests[i]) + "fail to create response").c_str()); + TRITONSERVER_ErrorDelete(err); + } else { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, response_err), + (GetRequestId(requests[i]) + "fail to send error response").c_str()); + } + + if (release_request) { + LOG_IF_ERROR( + TRITONBACKEND_RequestRelease( + requests[i], TRITONSERVER_REQUEST_RELEASE_ALL), + "fail to release request"); + 
requests[i] = nullptr; + } + } + + TRITONSERVER_ErrorDelete(response_err); +} + +void +SendErrorForResponses( + std::vector* responses, + const uint32_t response_count, TRITONSERVER_Error* response_err) +{ + for (size_t i = 0; i < response_count; i++) { + TRITONBACKEND_Response* response = (*responses)[i]; + if (response != nullptr) { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, response_err), + "fail to send error response"); + (*responses)[i] = nullptr; + } + } + + TRITONSERVER_ErrorDelete(response_err); +} + +TRITONSERVER_Error* +CopyBuffer( + const std::string& msg, const TRITONSERVER_MemoryType src_memory_type, + const int64_t src_memory_type_id, + const TRITONSERVER_MemoryType dst_memory_type, + const int64_t dst_memory_type_id, const size_t byte_size, const void* src, + void* dst, cudaStream_t cuda_stream, bool* cuda_used, + const bool copy_on_stream) +{ + *cuda_used = false; + + if (byte_size > 0) { + if (src == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string( + msg + ": attempted a copy of " + std::to_string(byte_size) + + " Bytes from an uninitialized memory") + .c_str()); + } + + if (dst == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string( + msg + ": attempted a copy of " + std::to_string(byte_size) + + " Bytes to an uninitialized memory") + .c_str()); + } + } + + + // For CUDA memcpy, if copy_on_stream is false, all host to host copy will be + // blocked in respect to the host, so use memcpy() directly. In this case, + // need to be careful on whether the src buffer is valid. + if ((src_memory_type != TRITONSERVER_MEMORY_GPU) && + (dst_memory_type != TRITONSERVER_MEMORY_GPU)) { +#ifdef TRITON_ENABLE_GPU + if (copy_on_stream) { + auto params = new CopyParams(dst, src, byte_size); + cudaLaunchHostFunc( + cuda_stream, MemcpyHost, reinterpret_cast(params)); + *cuda_used = true; + } else { + memcpy(dst, src, byte_size); + } +#else + memcpy(dst, src, byte_size); +#endif // TRITON_ENABLE_GPU + } else { +#ifdef TRITON_ENABLE_GPU + // [TODO] use cudaMemcpyDefault if UVM is supported for the device + auto copy_kind = cudaMemcpyDeviceToDevice; + if (src_memory_type != TRITONSERVER_MEMORY_GPU) { + copy_kind = cudaMemcpyHostToDevice; + } else if (dst_memory_type != TRITONSERVER_MEMORY_GPU) { + copy_kind = cudaMemcpyDeviceToHost; + } + + if ((src_memory_type_id != dst_memory_type_id) && + (copy_kind == cudaMemcpyDeviceToDevice)) { + RETURN_IF_CUDA_ERROR( + cudaMemcpyPeerAsync( + dst, dst_memory_type_id, src, src_memory_type_id, byte_size, + cuda_stream), + TRITONSERVER_ERROR_INTERNAL, msg + ": failed to perform CUDA copy"); + } else { + RETURN_IF_CUDA_ERROR( + cudaMemcpyAsync(dst, src, byte_size, copy_kind, cuda_stream), + TRITONSERVER_ERROR_INTERNAL, msg + ": failed to perform CUDA copy"); + } + + *cuda_used = true; +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string(msg + ": try to use CUDA copy while GPU is not supported") + .c_str()); +#endif // TRITON_ENABLE_GPU + } + + return nullptr; // success +} + +TRITONSERVER_Error* +GetDirectoryContents(const std::string& path, std::set* contents) +{ +#ifdef _WIN32 + WIN32_FIND_DATA entry; + HANDLE dir = FindFirstFile(path.c_str(), &entry); + if (dir == INVALID_HANDLE_VALUE) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("failed to open directory: ") + path).c_str()); + } + if ((entry.cFileName != ".") && (entry.cFileName != "..")) { + 
contents->insert(entry.cFileName); + } + while (FindNextFileA(dir, &entry)) { + if ((entry.cFileName != ".") && (entry.cFileName != "..")) { + contents->insert(entry.cFileName); + } + } + + FindClose(dir); +#else + DIR* dir = opendir(path.c_str()); + if (dir == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("failed to open directory: ") + path).c_str()); + } + + struct dirent* entry; + while ((entry = readdir(dir)) != nullptr) { + std::string entryname = entry->d_name; + if ((entryname != ".") && (entryname != "..")) { + contents->insert(entryname); + } + } + + closedir(dir); +#endif + return nullptr; // success +} + +TRITONSERVER_Error* +FileExists(const std::string& path, bool* exists) +{ + *exists = (access(path.c_str(), F_OK) == 0); + return nullptr; // success +} + +TRITONSERVER_Error* +ReadTextFile(const std::string& path, std::string* contents) +{ + std::ifstream in(path, std::ios::in | std::ios::binary); + if (!in) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + ("failed to open/read file '" + path + "': " + strerror(errno)) + .c_str()); + } + + in.seekg(0, std::ios::end); + contents->resize(in.tellg()); + in.seekg(0, std::ios::beg); + in.read(&(*contents)[0], contents->size()); + in.close(); + + return nullptr; // success +} + +TRITONSERVER_Error* +IsDirectory(const std::string& path, bool* is_dir) +{ + *is_dir = false; + + struct stat st; + if (stat(path.c_str(), &st) != 0) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("failed to stat file ") + path).c_str()); + } + + *is_dir = S_ISDIR(st.st_mode); + return nullptr; // success +} + +std::string +JoinPath(std::initializer_list segments) +{ + std::string joined; + + for (const auto& seg : segments) { + if (joined.empty()) { + joined = seg; + } else if (!seg.empty() && (seg[0] == '/')) { // IsAbsolutePath(seg) + if (joined[joined.size() - 1] == '/') { + joined.append(seg.substr(1)); + } else { + joined.append(seg); + } + } else { // !IsAbsolutePath(seg) + if (joined[joined.size() - 1] != '/') { + joined.append("/"); + } + joined.append(seg); + } + } + + return joined; +} + +TRITONSERVER_Error* +ModelPaths( + const std::string& model_repository_path, uint64_t version, + const bool ignore_directories, const bool ignore_files, + std::unordered_map* model_paths) +{ + std::set model_files; + // Read all the files in 'path' and filter by type for different requirements + auto path = JoinPath({model_repository_path, std::to_string(version)}); + RETURN_IF_ERROR(GetDirectoryContents(path, &model_files)); + if (ignore_directories) { + // Erase directory entries... + for (auto iter = model_files.begin(); iter != model_files.end();) { + bool is_dir; + RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir)); + if (is_dir) { + iter = model_files.erase(iter); + } else { + ++iter; + } + } + } + if (ignore_files) { + // Erase non-directory entries... 
+ for (auto iter = model_files.begin(); iter != model_files.end();) { + bool is_dir; + RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir)); + if (!is_dir) { + iter = model_files.erase(iter); + } else { + ++iter; + } + } + } + + for (const auto& filename : model_files) { + const auto model_path = JoinPath({path, filename}); + model_paths->emplace( + std::piecewise_construct, std::make_tuple(filename), + std::make_tuple(model_path)); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +CreateCudaStream( + const int device_id, const int cuda_stream_priority, cudaStream_t* stream) +{ + *stream = nullptr; + +#ifdef TRITON_ENABLE_GPU + // Make sure that correct device is set before creating stream and + // then restore the device to what was set by the caller. + int current_device; + auto cuerr = cudaGetDevice(¤t_device); + bool overridden = false; + if (cuerr == cudaSuccess) { + overridden = (current_device != device_id); + if (overridden) { + cuerr = cudaSetDevice(device_id); + } + } + + if (cuerr == cudaSuccess) { + cuerr = cudaStreamCreateWithPriority( + stream, cudaStreamDefault, cuda_stream_priority); + } + + if (overridden) { + cudaSetDevice(current_device); + } + + if (cuerr != cudaSuccess) { + *stream = nullptr; + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("unable to create stream: ") + cudaGetErrorString(cuerr)) + .c_str()); + } +#endif // TRITON_ENABLE_GPU + + return nullptr; // success +} + +TRITONSERVER_Error* +ParseLongLongValue(const std::string& value, int64_t* parsed_value) +{ + try { + *parsed_value = std::stoll(value); + } + catch (const std::invalid_argument& ia) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("failed to convert '") + value + + "' to long long integral number") + .c_str()); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +ParseUnsignedLongLongValue(const std::string& value, uint64_t* parsed_value) +{ + try { + *parsed_value = std::stoull(value); + } + catch (const std::invalid_argument& ia) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("failed to convert '") + value + + "' to unsigned long long integral number") + .c_str()); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +ParseBoolValue(const std::string& value, bool* parsed_value) +{ + std::string lvalue = value; + std::transform( + lvalue.begin(), lvalue.end(), lvalue.begin(), + [](unsigned char c) { return std::tolower(c); }); + + if ((lvalue == "true") || (lvalue == "on") || (lvalue == "1")) { + *parsed_value = true; + return nullptr; // success + } + if ((lvalue == "false") || (lvalue == "off") || (lvalue == "0")) { + *parsed_value = false; + return nullptr; // success + } + + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("failed to convert '") + value + "' to boolean").c_str()); +} + +TRITONSERVER_Error* +ParseIntValue(const std::string& value, int* parsed_value) +{ + try { + *parsed_value = std::stoi(value); + } + catch (const std::invalid_argument& ia) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("failed to convert '") + value + "' to integral number") + .c_str()); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +ParseDoubleValue(const std::string& value, double* parsed_value) +{ + try { + *parsed_value = std::stod(value); + } + catch (const std::invalid_argument& ia) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("failed to convert 
'") + value + "' to double number") + .c_str()); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +GetParameterValue( + triton::common::TritonJson::Value& params, const std::string& key, + std::string* value) +{ + triton::common::TritonJson::Value json_value; + RETURN_ERROR_IF_FALSE( + params.Find(key.c_str(), &json_value), TRITONSERVER_ERROR_NOT_FOUND, + std::string("model configuration is missing the parameter ") + key); + RETURN_IF_ERROR(json_value.MemberAsString("string_value", value)); + return nullptr; // success +} + +TRITONSERVER_Error* +BatchInput::ParseFromModelConfig( + triton::common::TritonJson::Value& config, + std::vector* batch_inputs) +{ + batch_inputs->clear(); + triton::common::TritonJson::Value bis; + RETURN_IF_ERROR(config.MemberAsArray("batch_input", &bis)); + for (size_t i = 0; i < bis.ArraySize(); ++i) { + triton::common::TritonJson::Value bi; + RETURN_IF_ERROR(bis.IndexAsObject(i, &bi)); + batch_inputs->emplace_back(); + RETURN_IF_ERROR(batch_inputs->back().Init(bi)); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +BatchInput::Init(triton::common::TritonJson::Value& bi_config) +{ + { + triton::common::TritonJson::Value bi_target_names; + RETURN_IF_ERROR(bi_config.MemberAsArray("target_name", &bi_target_names)); + for (size_t i = 0; i < bi_target_names.ArraySize(); ++i) { + std::string tn; + RETURN_IF_ERROR(bi_target_names.IndexAsString(i, &tn)); + target_names_.emplace_back(std::move(tn)); + } + } + { + RETURN_IF_ERROR(bi_config.MemberAsString("kind", &kind_str_)); + if (kind_str_ == "BATCH_ELEMENT_COUNT") { + kind_ = Kind::BATCH_ELEMENT_COUNT; + } else if (kind_str_ == "BATCH_ACCUMULATED_ELEMENT_COUNT") { + kind_ = Kind::BATCH_ACCUMULATED_ELEMENT_COUNT; + } else if (kind_str_ == "BATCH_ACCUMULATED_ELEMENT_COUNT_WITH_ZERO") { + kind_ = Kind::BATCH_ACCUMULATED_ELEMENT_COUNT_WITH_ZERO; + } else if (kind_str_ == "BATCH_MAX_ELEMENT_COUNT_AS_SHAPE") { + kind_ = Kind::BATCH_MAX_ELEMENT_COUNT_AS_SHAPE; + } else if (kind_str_ == "BATCH_ITEM_SHAPE") { + kind_ = Kind::BATCH_ITEM_SHAPE; + } else if (kind_str_ == "BATCH_ITEM_SHAPE_FLATTEN") { + kind_ = Kind::BATCH_ITEM_SHAPE_FLATTEN; + } else { + RETURN_ERROR_IF_FALSE( + false, TRITONSERVER_ERROR_INVALID_ARG, + std::string("unexpected batch input kind '" + kind_str_ + "'")); + } + } + { + std::string bi_dtype; + RETURN_IF_ERROR(bi_config.MemberAsString("data_type", &bi_dtype)); + data_type_ = ModelConfigDataTypeToTritonServerDataType(bi_dtype); + RETURN_ERROR_IF_TRUE( + data_type_ == TRITONSERVER_TYPE_INVALID, TRITONSERVER_ERROR_INVALID_ARG, + std::string("unexpected batch input data type '" + bi_dtype + "'")); + } + { + triton::common::TritonJson::Value bi_source_inputs; + RETURN_IF_ERROR(bi_config.MemberAsArray("source_input", &bi_source_inputs)); + for (size_t i = 0; i < bi_source_inputs.ArraySize(); ++i) { + std::string si; + RETURN_IF_ERROR(bi_source_inputs.IndexAsString(i, &si)); + source_inputs_.emplace_back(std::move(si)); + } + } + return nullptr; // success +} + +TRITONSERVER_DataType +ModelConfigDataTypeToTritonServerDataType(const std::string& data_type_str) +{ + // Must start with "TYPE_". 
+ if (data_type_str.rfind("TYPE_", 0) != 0) { + return TRITONSERVER_TYPE_INVALID; + } + + const std::string dtype = data_type_str.substr(strlen("TYPE_")); + + if (dtype == "BOOL") { + return TRITONSERVER_TYPE_BOOL; + } else if (dtype == "UINT8") { + return TRITONSERVER_TYPE_UINT8; + } else if (dtype == "UINT16") { + return TRITONSERVER_TYPE_UINT16; + } else if (dtype == "UINT32") { + return TRITONSERVER_TYPE_UINT32; + } else if (dtype == "UINT64") { + return TRITONSERVER_TYPE_UINT64; + } else if (dtype == "INT8") { + return TRITONSERVER_TYPE_INT8; + } else if (dtype == "INT16") { + return TRITONSERVER_TYPE_INT16; + } else if (dtype == "INT32") { + return TRITONSERVER_TYPE_INT32; + } else if (dtype == "INT64") { + return TRITONSERVER_TYPE_INT64; + } else if (dtype == "FP16") { + return TRITONSERVER_TYPE_FP16; + } else if (dtype == "FP32") { + return TRITONSERVER_TYPE_FP32; + } else if (dtype == "FP64") { + return TRITONSERVER_TYPE_FP64; + } else if (dtype == "STRING") { + return TRITONSERVER_TYPE_BYTES; + } else if (dtype == "BF16") { + return TRITONSERVER_TYPE_BF16; + } + + return TRITONSERVER_TYPE_INVALID; +} + +TRITONSERVER_Error* +BatchOutput::ParseFromModelConfig( + triton::common::TritonJson::Value& config, + std::vector* batch_outputs) +{ + batch_outputs->clear(); + triton::common::TritonJson::Value bos; + RETURN_IF_ERROR(config.MemberAsArray("batch_output", &bos)); + for (size_t i = 0; i < bos.ArraySize(); ++i) { + batch_outputs->emplace_back(); + auto& batch_output = batch_outputs->back(); + triton::common::TritonJson::Value bo; + RETURN_IF_ERROR(bos.IndexAsObject(i, &bo)); + { + triton::common::TritonJson::Value bo_target_names; + RETURN_IF_ERROR(bo.MemberAsArray("target_name", &bo_target_names)); + for (size_t i = 0; i < bo_target_names.ArraySize(); ++i) { + std::string tn; + RETURN_IF_ERROR(bo_target_names.IndexAsString(i, &tn)); + batch_output.target_names_.emplace_back(std::move(tn)); + } + } + { + std::string bo_kind; + RETURN_IF_ERROR(bo.MemberAsString("kind", &bo_kind)); + if (bo_kind == "BATCH_SCATTER_WITH_INPUT_SHAPE") { + batch_output.kind_ = Kind::BATCH_SCATTER_WITH_INPUT_SHAPE; + // Keep track of the output info for later cross reference with input + int64_t mbs = 0; + RETURN_IF_ERROR(config.MemberAsInt("max_batch_size", &mbs)); + if (mbs != 0) { + batch_output.shape_.push_back(-1); + } + triton::common::TritonJson::Value ios; + RETURN_IF_ERROR(config.MemberAsArray("output", &ios)); + for (size_t i = 0; i < ios.ArraySize(); i++) { + triton::common::TritonJson::Value io; + RETURN_IF_ERROR(ios.IndexAsObject(i, &io)); + std::string io_name; + RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); + if (io_name == batch_output.target_names_[0]) { + std::string io_dtype; + RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype)); + batch_output.data_type_ = + ModelConfigDataTypeToTritonServerDataType(io_dtype); + // If a reshape is provided for the input then use that when + // validating that the model matches what is expected. 
+ triton::common::TritonJson::Value reshape; + if (io.Find("reshape", &reshape)) { + RETURN_IF_ERROR( + ParseShape(reshape, "shape", &batch_output.shape_)); + } else { + RETURN_IF_ERROR(ParseShape(io, "dims", &batch_output.shape_)); + } + break; + } + } + } else { + RETURN_ERROR_IF_FALSE( + false, TRITONSERVER_ERROR_INVALID_ARG, + std::string("unexpected batch output kind '" + bo_kind + "'")); + } + } + { + triton::common::TritonJson::Value bo_source_inputs; + RETURN_IF_ERROR(bo.MemberAsArray("source_input", &bo_source_inputs)); + for (size_t i = 0; i < bo_source_inputs.ArraySize(); ++i) { + std::string si; + RETURN_IF_ERROR(bo_source_inputs.IndexAsString(i, &si)); + batch_output.source_inputs_.emplace_back(std::move(si)); + } + } + } + + return nullptr; // success +} + +TRITONSERVER_Error* +TryParseModelStringParameter( + triton::common::TritonJson::Value& params, const std::string& mkey, + std::string* value, const std::string& default_value) +{ + triton::common::TritonJson::Value json_value; + if (params.Find(mkey.c_str(), &json_value)) { + RETURN_IF_ERROR(json_value.MemberAsString("string_value", value)); + } else { + *value = default_value; + } + + return nullptr; // success +} + +TRITONSERVER_Error* +TryParseModelStringParameter( + triton::common::TritonJson::Value& params, const std::string& mkey, + int* value, const int& default_value) +{ + triton::common::TritonJson::Value json_value; + if (params.Find(mkey.c_str(), &json_value)) { + std::string string_value; + RETURN_IF_ERROR(json_value.MemberAsString("string_value", &string_value)); + return ParseIntValue(string_value, value); + } else { + *value = default_value; + return nullptr; // success + } +} + +TRITONSERVER_Error* +TryParseModelStringParameter( + triton::common::TritonJson::Value& params, const std::string& mkey, + bool* value, const bool& default_value) +{ + triton::common::TritonJson::Value json_value; + if (params.Find(mkey.c_str(), &json_value)) { + std::string string_value; + RETURN_IF_ERROR(json_value.MemberAsString("string_value", &string_value)); + return ParseBoolValue(string_value, value); + } else { + *value = default_value; + return nullptr; // success + } +} + +TRITONSERVER_Error* +TryParseModelStringParameter( + triton::common::TritonJson::Value& params, const std::string& mkey, + uint64_t* value, const uint64_t& default_value) +{ + triton::common::TritonJson::Value json_value; + if (params.Find(mkey.c_str(), &json_value)) { + std::string string_value; + RETURN_IF_ERROR(json_value.MemberAsString("string_value", &string_value)); + return ParseUnsignedLongLongValue(string_value, value); + } else { + *value = default_value; + return nullptr; // success + } +} + +namespace { + +template +TRITONSERVER_Error* +BufferAsTypedString( + std::string& str, const char* buffer, const size_t element_cnt) +{ + const T* vals = reinterpret_cast(buffer); + + str += "[ "; + for (size_t i = 0; i < element_cnt; ++i) { + const T& v = vals[i]; + if (i != 0) { + str += ", "; + } + str += std::to_string(v); + } + + str += " ]"; + + return nullptr; // success +} + +} // namespace + + +TRITONSERVER_Error* +BufferAsTypedString( + std::string& str, const char* buffer, size_t buffer_byte_size, + TRITONSERVER_DataType datatype) +{ + const size_t element_cnt = + buffer_byte_size / TRITONSERVER_DataTypeByteSize(datatype); + + switch (datatype) { + case TRITONSERVER_TYPE_UINT8: + return BufferAsTypedString(str, buffer, element_cnt); + case TRITONSERVER_TYPE_UINT16: + return BufferAsTypedString(str, buffer, element_cnt); + case 
TRITONSERVER_TYPE_UINT32: + return BufferAsTypedString(str, buffer, element_cnt); + case TRITONSERVER_TYPE_UINT64: + return BufferAsTypedString(str, buffer, element_cnt); + + case TRITONSERVER_TYPE_INT8: + return BufferAsTypedString(str, buffer, element_cnt); + case TRITONSERVER_TYPE_INT16: + return BufferAsTypedString(str, buffer, element_cnt); + case TRITONSERVER_TYPE_INT32: + return BufferAsTypedString(str, buffer, element_cnt); + case TRITONSERVER_TYPE_INT64: + return BufferAsTypedString(str, buffer, element_cnt); + + case TRITONSERVER_TYPE_FP32: + return BufferAsTypedString(str, buffer, element_cnt); + case TRITONSERVER_TYPE_FP64: + return BufferAsTypedString(str, buffer, element_cnt); + + default: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + std::string("class result not available for output due to " + "unsupported type '") + + std::string(TRITONSERVER_DataTypeString(datatype)) + "'") + .c_str()); + } + + return nullptr; // success +} + +std::string +GetRequestId(TRITONBACKEND_Request* request) +{ + const char* request_id = nullptr; + LOG_IF_ERROR( + TRITONBACKEND_RequestId(request, &request_id), + "unable to retrieve request ID string"); + if ((request_id == nullptr) || (request_id[0] == '\0')) { + request_id = ""; + } + return std::string("[request id: ") + request_id + "] "; +} + +}} // namespace triton::backend diff --git a/3rdparty/backend-r22.12/src/backend_input_collector.cc b/3rdparty/backend-r22.12/src/backend_input_collector.cc new file mode 100644 index 0000000000000000000000000000000000000000..a6f0cebd7921fcfbb5ce0e820c2711d587a059dd --- /dev/null +++ b/3rdparty/backend-r22.12/src/backend_input_collector.cc @@ -0,0 +1,1310 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
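Editorial note (not part of the vendored diff): the backend_common.cc utilities above follow a common convention in which copies that may run asynchronously on a CUDA stream report that fact through a `cuda_used` out-parameter, and the caller must synchronize before touching the destination buffer. Below is a minimal sketch of that calling pattern using `ReadInputTensor`; the input name, buffer size, and the in-scope `request` and `stream` variables are assumptions for illustration only.

```cpp
// Illustrative fragment from inside a hypothetical execute function.
// Assumes 'request' is a TRITONBACKEND_Request* and 'stream' is the
// instance's cudaStream_t (nullptr when GPU support is disabled).
std::vector<char> staging(64 * sizeof(float));  // hypothetical capacity
size_t staging_byte_size = staging.size();      // in: capacity, out: actual size
bool cuda_used = false;

TRITONSERVER_Error* err = triton::backend::ReadInputTensor(
    request, "INPUT0" /* hypothetical input name */, staging.data(),
    &staging_byte_size, TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */,
    stream, &cuda_used, nullptr /* host_policy_name */,
    false /* copy_on_stream */);

#ifdef TRITON_ENABLE_GPU
if ((err == nullptr) && cuda_used) {
  // A cudaMemcpyAsync was issued on 'stream'; wait for it to finish before
  // reading the staging buffer on the host.
  cudaStreamSynchronize(stream);
}
#endif  // TRITON_ENABLE_GPU
```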
+ +#include "triton/backend/backend_input_collector.h" + +#include +#include "triton/backend/backend_common.h" +#ifdef TRITON_ENABLE_GPU +#include "kernel.h" +#endif // TRITON_ENABLE_GPU + +namespace triton { namespace backend { + +// +// BackendInputCollector::InputIterator +// + +BackendInputCollector::InputIterator::InputIterator( + TRITONBACKEND_Request** requests, const uint32_t request_count, + std::vector* responses, const char* input_name, + const char* host_policy_name, const bool coalesce_request_input) + : requests_(requests), request_count_(request_count), responses_(responses), + input_name_(input_name), host_policy_(host_policy_name), + coalesce_request_input_(coalesce_request_input), curr_request_idx_(0), + curr_buffer_idx_(0), reach_end_(false) +{ + auto& response = (*responses_)[curr_request_idx_]; + RESPOND_AND_SET_NULL_IF_ERROR( + &response, TRITONBACKEND_RequestInput( + requests_[curr_request_idx_], input_name_, &curr_input_)); + RESPOND_AND_SET_NULL_IF_ERROR( + &response, TRITONBACKEND_InputPropertiesForHostPolicy( + curr_input_, host_policy_, nullptr, nullptr, nullptr, + nullptr, nullptr, &curr_buffer_cnt_)); +} + +bool +BackendInputCollector::InputIterator::GetNextContiguousInput( + ContiguousBuffer* input) +{ + if (reach_end_ || (curr_buffer_idx_ >= curr_buffer_cnt_)) { + return false; + } + + // Get the first buffer + TRITONBACKEND_InputBufferForHostPolicy( + curr_input_, host_policy_, curr_buffer_idx_, + reinterpret_cast(&input->memory_desc_.buffer_), + &input->memory_desc_.byte_size_, &input->memory_desc_.memory_type_, + &input->memory_desc_.memory_type_id_); + ++curr_buffer_idx_; + input->start_request_idx_ = curr_request_idx_; + input->end_request_idx_ = curr_request_idx_; + if (!coalesce_request_input_) { + if (curr_buffer_idx_ >= curr_buffer_cnt_) { + ++curr_request_idx_; + if (curr_request_idx_ < request_count_) { + auto& response = (*responses_)[curr_request_idx_]; + RESPOND_AND_SET_NULL_IF_ERROR( + &response, + TRITONBACKEND_RequestInput( + requests_[curr_request_idx_], input_name_, &curr_input_)); + RESPOND_AND_SET_NULL_IF_ERROR( + &response, TRITONBACKEND_InputPropertiesForHostPolicy( + curr_input_, host_policy_, nullptr, nullptr, nullptr, + nullptr, nullptr, &curr_buffer_cnt_)); + // reset buffer idx + curr_buffer_idx_ = 0; + } else { + reach_end_ = true; + } + } + return true; + } + + do { + for (; curr_buffer_idx_ < curr_buffer_cnt_; ++curr_buffer_idx_) { + const void* next_buffer; + size_t next_buffer_byte_size; + TRITONSERVER_MemoryType next_memory_type; + int64_t next_memory_type_id; + TRITONBACKEND_InputBufferForHostPolicy( + curr_input_, host_policy_, curr_buffer_idx_, &next_buffer, + &next_buffer_byte_size, &next_memory_type, &next_memory_type_id); + if (((input->memory_desc_.buffer_ + input->memory_desc_.byte_size_) != + next_buffer) || + (input->memory_desc_.memory_type_ != next_memory_type) || + (input->memory_desc_.memory_type_id_ != next_memory_type_id)) { + return true; + } + input->memory_desc_.byte_size_ += next_buffer_byte_size; + input->end_request_idx_ = curr_request_idx_; + } + // Iterated all buffers for current request, check next + ++curr_request_idx_; + if (curr_request_idx_ < request_count_) { + auto& response = (*responses_)[curr_request_idx_]; + RESPOND_AND_SET_NULL_IF_ERROR( + &response, + TRITONBACKEND_RequestInput( + requests_[curr_request_idx_], input_name_, &curr_input_)); + RESPOND_AND_SET_NULL_IF_ERROR( + &response, TRITONBACKEND_InputPropertiesForHostPolicy( + curr_input_, host_policy_, nullptr, nullptr, nullptr, + 
nullptr, nullptr, &curr_buffer_cnt_)); + // reset buffer idx + curr_buffer_idx_ = 0; + } + } while (curr_request_idx_ < request_count_); + reach_end_ = true; + return true; +} + +// +// BackendInputCollector +// + +bool +BackendInputCollector::GetInputBufferIfContiguous( + const char* input_name, const char** buffer, size_t* buffer_byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) +{ + *buffer = nullptr; + *buffer_byte_size = 0; + const char* expected_next_buffer = nullptr; + bool contiguous = true; + for (size_t idx = 0; idx < request_count_; idx++) { + auto& request = requests_[idx]; + auto& response = (*responses_)[idx]; + + TRITONBACKEND_Input* input; + RESPOND_AND_SET_NULL_IF_ERROR( + &response, TRITONBACKEND_RequestInput(request, input_name, &input)); + uint64_t byte_size; + uint32_t buffer_count; + RESPOND_AND_SET_NULL_IF_ERROR( + &response, TRITONBACKEND_InputPropertiesForHostPolicy( + input, host_policy_cstr_, nullptr, nullptr, nullptr, + nullptr, &byte_size, &buffer_count)); + for (size_t idx = 0; idx < buffer_count; ++idx) { + const void* src_buffer; + size_t src_byte_size; + TRITONSERVER_MemoryType src_memory_type; + int64_t src_memory_type_id; + + RESPOND_AND_SET_NULL_IF_ERROR( + &response, + TRITONBACKEND_InputBufferForHostPolicy( + input, host_policy_cstr_, idx, &src_buffer, &src_byte_size, + &src_memory_type, &src_memory_type_id)); + if (*buffer != nullptr) { + // If have seen the second buffer while coalescing input is not + // requested, treat the inputs are not contiguous + if (coalesce_request_input_ && (expected_next_buffer == src_buffer) && + (*memory_type == src_memory_type) && + (*memory_type_id == src_memory_type_id)) { + expected_next_buffer += src_byte_size; + } else { + contiguous = false; + } + // Want to know total buffer byte size even if it is not contiguous + *buffer_byte_size += src_byte_size; + } else { + *buffer = reinterpret_cast(src_buffer); + *memory_type = src_memory_type; + *memory_type_id = src_memory_type_id; + *buffer_byte_size = src_byte_size; + expected_next_buffer = *buffer + src_byte_size; + } + } + } + return contiguous; +} + +void +BackendInputCollector::ProcessTensor( + const char* input_name, char* buffer, const size_t buffer_byte_size, + const TRITONSERVER_MemoryType memory_type, const int64_t memory_type_id) +{ + // A value of CPU_PINNED indicates that pinned memory buffer is not + // needed for this tensor. Any other value indicates that a pinned + // memory buffer is needed when the target memory type matches + // 'use_pinned_memory_type'. + TRITONSERVER_MemoryType use_pinned_memory_type = + TRITONSERVER_MEMORY_CPU_PINNED; + if (pinned_enabled_) { + use_pinned_memory_type = GetUsePinnedMemoryType(memory_type); + } + const bool use_kernel = (kernel_buffer_threshold_ != 0); + + size_t buffer_offset = 0; + + InputIterator ii( + requests_, request_count_, responses_, input_name, host_policy_cstr_, + coalesce_request_input_); + ContiguousBuffer input; + while (ii.GetNextContiguousInput(&input)) { + // If there are pending copies from tensor buffer that is not + // contiguous with 'response's part of that buffer, then need to + // go ahead and perform the pending copies so that can start a new + // contiguous region if necessary. 
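// ---------------------------------------------------------------------------
// Illustrative usage sketch (editorial, not part of the upstream file): how a
// backend's execute path would typically drive the ProcessTensor()/Finalize()
// pair defined in this file. The BackendInputCollector constructor arguments
// shown here are assumed from backend_input_collector.h, which is not part of
// this diff; 'device_buffer' and 'total_byte_size' are hypothetical values
// computed by the caller.
//
//   std::vector<TRITONBACKEND_Response*> responses;  // one per request
//   BackendInputCollector collector(
//       requests, request_count, &responses, memory_manager,
//       true /* pinned_enabled */, stream);
//   // Gather every request's INPUT0 buffers into one contiguous device
//   // buffer, staging through pinned memory where that is beneficial.
//   collector.ProcessTensor(
//       "INPUT0", device_buffer, total_byte_size, TRITONSERVER_MEMORY_GPU,
//       0 /* memory_type_id */);
//   // Finalize() returns true when an asynchronous CUDA copy was issued and
//   // the caller must synchronize 'stream' before reading 'device_buffer'.
//   if (collector.Finalize()) {
//     cudaStreamSynchronize(stream);
//   }
// ---------------------------------------------------------------------------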
+ if ((pending_pinned_byte_size_ > 0) && + (buffer_offset != + (pending_pinned_byte_size_ + pending_pinned_offset_))) { + need_sync_ |= FlushPendingPinned( + buffer, buffer_byte_size, memory_type, memory_type_id); + } + if ((pending_copy_kernel_buffer_byte_size_ > 0) && + (buffer_offset != (pending_copy_kernel_buffer_byte_size_ + + pending_copy_kernel_buffer_offset_))) { + need_sync_ |= FlushPendingCopyKernel( + buffer, buffer_byte_size, memory_type, memory_type_id); + } + + need_sync_ |= SetInputTensor( + input_name, input, buffer, buffer_byte_size, memory_type, + memory_type_id, buffer_offset, use_pinned_memory_type, use_kernel, + true); + + buffer_offset += input.memory_desc_.byte_size_; + } + + // Done with the tensor, flush any pending pinned copies. + need_sync_ |= + FlushPendingPinned(buffer, buffer_byte_size, memory_type, memory_type_id); + need_sync_ |= FlushPendingCopyKernel( + buffer, buffer_byte_size, memory_type, memory_type_id); +#ifdef TRITON_ENABLE_GPU + if (need_sync_ && (event_ != nullptr)) { + cudaEventRecord(event_, stream_); + } +#endif // TRITON_ENABLE_GPU +} + +TRITONSERVER_Error* +BackendInputCollector::ProcessTensor( + const char* input_name, char* buffer, const size_t buffer_byte_size, + const std::vector>& + allowed_input_types, + const char** dst_buffer, size_t* dst_buffer_byte_size, + TRITONSERVER_MemoryType* dst_memory_type, int64_t* dst_memory_type_id) +{ + if (buffer == nullptr) { + if (allowed_input_types.size() == 0) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "'allowed_input_types' must contain at least one pair of memory type " + "and id"); + } + if (GetInputBufferIfContiguous( + input_name, dst_buffer, dst_buffer_byte_size, dst_memory_type, + dst_memory_type_id)) { + // zero size buffer will be treated as contiguous as well, + // but we want to invoke backend memory to have a valid address. 
+ if (*dst_buffer_byte_size != 0) { + // If the buffer is contiguous, check if the caller expects its type + for (const auto& allowed_type : allowed_input_types) { + if ((*dst_memory_type == allowed_type.first) && + ((*dst_memory_type_id == allowed_type.second))) { + return nullptr; // success + } + } + } + } + // A separate buffer is needed + BackendMemory* backend_memory = nullptr; + for (const auto& allowed_type : allowed_input_types) { + std::vector alloc_types; + const int64_t memory_type_id = allowed_type.second; + switch (allowed_type.first) { + case TRITONSERVER_MEMORY_GPU: + alloc_types = {BackendMemory::AllocationType::GPU_POOL, + BackendMemory::AllocationType::GPU}; + break; + case TRITONSERVER_MEMORY_CPU_PINNED: + alloc_types = {BackendMemory::AllocationType::CPU_PINNED_POOL, + BackendMemory::AllocationType::CPU_PINNED}; + break; + case TRITONSERVER_MEMORY_CPU: + alloc_types = {BackendMemory::AllocationType::CPU}; + break; + } + auto err = BackendMemory::Create( + memory_manager_, alloc_types, memory_type_id, *dst_buffer_byte_size, + &backend_memory); + if (err != nullptr) { + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("unable to create backend memory for type: ") + + TRITONSERVER_MemoryTypeString(allowed_type.first) + + " id: " + std::to_string(memory_type_id) + ": " + + TRITONSERVER_ErrorMessage(err)) + .c_str()); + TRITONSERVER_ErrorDelete(err); + } else { + in_use_memories_.emplace_back(backend_memory); + break; + } + } + if (backend_memory == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("failed to allocate contiguous buffer for input '") + + input_name + "'") + .c_str()); + } + buffer = backend_memory->MemoryPtr(); + *dst_buffer = backend_memory->MemoryPtr(); + *dst_buffer_byte_size = backend_memory->ByteSize(); + *dst_memory_type = backend_memory->MemoryType(); + *dst_memory_type_id = backend_memory->MemoryTypeId(); + } else { + if (allowed_input_types.size() != 1) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "'allowed_input_types' must only contain the memory type and id of " + "'buffer'"); + } + *dst_buffer = buffer; + *dst_buffer_byte_size = buffer_byte_size; + *dst_memory_type = allowed_input_types[0].first; + *dst_memory_type_id = allowed_input_types[0].second; + } + if (*dst_buffer_byte_size != 0) { + ProcessTensor( + input_name, buffer, *dst_buffer_byte_size, *dst_memory_type, + *dst_memory_type_id); + } + return nullptr; // success +} + +bool +BackendInputCollector::Finalize() +{ +#ifdef TRITON_ENABLE_GPU + if ((!deferred_pinned_.empty()) && need_sync_) { + if (event_ != nullptr) { + cudaEventSynchronize(event_); + } else { + cudaStreamSynchronize(stream_); + } + need_sync_ = false; + } +#endif // TRITON_ENABLE_GPU + + // After the above sync all the GPU->pinned copies are complete. Any + // deferred copies of pinned->CPU can now be done. 
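// ---------------------------------------------------------------------------
// Illustrative sketch (editorial): calling the ProcessTensor() overload above
// with 'buffer' == nullptr, so the collector either reuses an already
// contiguous request buffer or allocates a destination itself, trying the
// allowed memory types in the order given. Variable names are illustrative;
// the element type of 'allowed_input_types' is assumed to be
// std::pair<TRITONSERVER_MemoryType, int64_t> based on its usage here.
//
//   const char* dst_buffer = nullptr;
//   size_t dst_byte_size = 0;
//   TRITONSERVER_MemoryType dst_mem_type;
//   int64_t dst_mem_type_id;
//   RETURN_IF_ERROR(collector.ProcessTensor(
//       "INPUT0", nullptr /* buffer */, 0 /* buffer_byte_size */,
//       {{TRITONSERVER_MEMORY_GPU, 0}, {TRITONSERVER_MEMORY_CPU_PINNED, 0}},
//       &dst_buffer, &dst_byte_size, &dst_mem_type, &dst_mem_type_id));
// ---------------------------------------------------------------------------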
+#ifdef TRITON_ENABLE_GPU + if (buffer_ready_event_ != nullptr) { + cudaEventSynchronize(buffer_ready_event_); + buffer_ready_event_ = nullptr; + } +#endif // TRITON_ENABLE_GPU + for (auto& def : deferred_pinned_) { + if (!def.finalized_) { + need_sync_ |= def.Finalize(stream_); + } + } + for (size_t i = 0; i < async_task_count_; i++) { + need_sync_ |= completion_queue_.Get(); + } + +#ifdef TRITON_ENABLE_GPU + // Record the new event location if deferred copies occur + if ((!deferred_pinned_.empty()) && need_sync_ && (event_ != nullptr)) { + cudaEventRecord(event_, stream_); + } +#endif // TRITON_ENABLE_GPU + + return need_sync_; +} + +bool +BackendInputCollector::DeferredPinned::Finalize(cudaStream_t stream) +{ + bool cuda_used = false; + auto err = CopyBuffer( + "pinned buffer", TRITONSERVER_MEMORY_CPU_PINNED, 0, tensor_memory_type_, + tensor_memory_id_, pinned_memory_size_, pinned_memory_, + tensor_buffer_ + tensor_buffer_offset_, stream, &cuda_used); + + // If something goes wrong with the copy all the pending + // responses fail... + if (err != nullptr) { + for (auto& pr : requests_) { + for (size_t idx = pr.start_request_idx_; idx <= pr.end_request_idx_; + ++idx) { + if ((*responses_)[idx] != nullptr) { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + (*responses_)[idx], TRITONSERVER_RESPONSE_COMPLETE_FINAL, + err), + "failed to send error response"); + (*responses_)[idx] = nullptr; + } + } + } + TRITONSERVER_ErrorDelete(err); + } + return cuda_used; +} + +bool +BackendInputCollector::SetInputTensor( + const char* input_name, const ContiguousBuffer& input, char* tensor_buffer, + const size_t tensor_buffer_byte_size, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id, const size_t tensor_buffer_offset, + const TRITONSERVER_MemoryType use_pinned_memory_type, const bool use_kernel, + const bool wait_buffer) +{ + bool cuda_copy = false; + + if ((tensor_buffer_offset + input.memory_desc_.byte_size_) > + tensor_buffer_byte_size) { + for (size_t i = input.start_request_idx_; i <= input.end_request_idx_; + ++i) { + RESPOND_AND_SET_NULL_IF_ERROR( + &(*responses_)[i], + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "unexpected total byte size " + + std::to_string( + tensor_buffer_offset + input.memory_desc_.byte_size_) + + " for input '" + input_name + "', expecting " + + std::to_string(tensor_buffer_byte_size)) + .c_str())); + } + return cuda_copy; + } + + // If the request buffer matches the memory type that should use an + // intermediate pinned memory buffer for the transfer, then just + // record the input as pending and increase the size required for + // the intermediate pinned buffer. We only do this check for the + // first buffer of an input and apply the same policy for all + // buffers. So if an inputs data is split over different memory + // types this may not be ideal but that should be a very rare + // situation. + if ((use_pinned_memory_type != TRITONSERVER_MEMORY_CPU_PINNED) && + (input.memory_desc_.memory_type_ == use_pinned_memory_type)) { + if (pending_pinned_byte_size_ == 0) { + pending_pinned_offset_ = tensor_buffer_offset; + } + + pending_pinned_byte_size_ += input.memory_desc_.byte_size_; + pending_pinned_input_buffers_.push_back(input); + return cuda_copy; + } + // [FIXME] support other direction if prove to be faster, all kernel + // handling code in this class asssumes the destination buffer is on device + // If the request buffer and the destination buffer are accessible by all + // GPUs (i.e. 
pinned, device), initiate the copy via copy CUDA kernel. + // We only do this check for the + // first buffer of an input and apply the same policy for all + // buffers. So if an inputs data is split over different memory + // types this may not be ideal but that should be a very rare + // situation. + // Currently checked direction: + // pinned -> device + // same device -> device + // different device -> device + if (use_kernel && + (input.memory_desc_.memory_type_ != TRITONSERVER_MEMORY_CPU) && + (tensor_memory_type == TRITONSERVER_MEMORY_GPU)) { + // [FIXME] Currently not allowing copy between devices as it requires + // peer-to-peer access to be enabled. Peer-to-peer is enabled by default, + // but server can still runs even if it fails to enable peer-to-peer. + // Should provide a utility to check whether a device pair allows direct + // access and use gather kernel accordingly + if ((input.memory_desc_.memory_type_ != TRITONSERVER_MEMORY_GPU) || + (input.memory_desc_.memory_type_id_ == tensor_memory_type_id)) { + if (pending_copy_kernel_buffer_byte_size_ == 0) { + pending_copy_kernel_buffer_offset_ = tensor_buffer_offset; + } + + pending_copy_kernel_buffer_byte_size_ += input.memory_desc_.byte_size_; + ++pending_copy_kernel_input_buffer_counts_; + pending_copy_kernel_input_buffers_.push_back(input); + return cuda_copy; + } + } + +#ifdef TRITON_ENABLE_GPU + if (wait_buffer && (buffer_ready_event_ != nullptr)) { + cudaEventSynchronize(buffer_ready_event_); + buffer_ready_event_ = nullptr; + } +#endif // TRITON_ENABLE_GPU + + // Direct copy without intermediate pinned memory. + bool cuda_used = false; + auto err = CopyBuffer( + input_name, input.memory_desc_.memory_type_, + input.memory_desc_.memory_type_id_, tensor_memory_type, + tensor_memory_type_id, input.memory_desc_.byte_size_, + input.memory_desc_.buffer_, tensor_buffer + tensor_buffer_offset, stream_, + &cuda_used, copy_on_stream_); + if (err != nullptr) { + for (size_t i = input.start_request_idx_; i <= input.end_request_idx_; + ++i) { + RESPOND_AND_SET_NULL_IF_ERROR( + &(*responses_)[i], + TRITONSERVER_ErrorNew( + TRITONSERVER_ErrorCode(err), TRITONSERVER_ErrorMessage(err))); + } + TRITONSERVER_ErrorDelete(err); + } + cuda_copy |= cuda_used; + return cuda_copy; +} + +bool +BackendInputCollector::FlushPendingPinned( + char* tensor_buffer, const size_t tensor_buffer_byte_size, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id) +{ + bool cuda_copy = false; + + // Will be copying from CPU->pinned->GPU or GPU->pinned->CPU + + // Attempt to allocate a pinned buffer to use for staging the + // copy... if we fail to allocated the pinned buffer then we just + // directly go CPU->GPU or GPU->CPU. + char* pinned_memory = nullptr; + int64_t pinned_memory_type_id = 0; + TRITONSERVER_MemoryType pinned_memory_type; + BackendMemory* backend_memory; + if (pending_pinned_byte_size_ > 0) { + TRITONSERVER_Error* err = BackendMemory::Create( + memory_manager_, + {BackendMemory::AllocationType::CPU_PINNED_POOL, + BackendMemory::AllocationType::CPU_PINNED}, + 0 /* memory_type_id */, pending_pinned_byte_size_, &backend_memory); + if (err != nullptr) { + TRITONSERVER_ErrorDelete(err); + } else { + pinned_memory = backend_memory->MemoryPtr(); + pinned_memory_type = backend_memory->MemoryType(); + pinned_memory_type_id = backend_memory->MemoryTypeId(); + } + } + + // If the pinned buffer wasn't actually allocated then just perform + // a direct copy. 
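// ---------------------------------------------------------------------------
// Editorial note with a minimal sketch: BackendMemory::Create (added by
// backend_memory.cc later in this patch) tries the listed allocation types in
// order and keeps the first that succeeds, so the call above prefers the
// server-managed pinned pool and only then falls back to a directly allocated
// pinned buffer; if both fail, the code below copies without a staging buffer.
//
//   BackendMemory* staging = nullptr;
//   TRITONSERVER_Error* err = BackendMemory::Create(
//       memory_manager_,
//       {BackendMemory::AllocationType::CPU_PINNED_POOL,
//        BackendMemory::AllocationType::CPU_PINNED},
//       0 /* memory_type_id */, byte_size, &staging);
//   // On success: staging->MemoryPtr() / ByteSize() describe the buffer.
// ---------------------------------------------------------------------------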
+ if (pinned_memory == nullptr) { + size_t offset = 0; + for (auto& pr : pending_pinned_input_buffers_) { + cuda_copy |= SetInputTensor( + "pinned fallback", pr, tensor_buffer, tensor_buffer_byte_size, + tensor_memory_type, tensor_memory_type_id, + pending_pinned_offset_ + offset, TRITONSERVER_MEMORY_CPU_PINNED, + false, true); + offset += pr.memory_desc_.byte_size_; + } + } + // We have a pinned buffer so copy the pending input buffer(s) into + // the pinned memory. + else { // pinned_memory_type == TRITONSERVER_MEMORY_CPU_PINNED + bool cuda_used = false; + size_t offset = 0; + if (!use_async_cpu_copy_) { + for (auto& pr : pending_pinned_input_buffers_) { + cuda_used |= SetInputTensor( + "pinned H2H", pr, pinned_memory, pending_pinned_byte_size_, + TRITONSERVER_MEMORY_CPU_PINNED, 0 /* memory_type_id */, offset, + TRITONSERVER_MEMORY_CPU_PINNED, false, true); + offset += pr.memory_desc_.byte_size_; + } + + cuda_copy |= cuda_used; + + // If the copy was not async (i.e. if request input was in CPU so + // a CPU->CPU-PINNED copy was performed above), then the pinned + // buffer now holds the tensor contents and we can immediately + // issue the copies from the pinned buffer to the tensor. + // + // Otherwise the GPU->CPU-PINNED async copies are in flight and we + // simply remember the pinned buffer and the corresponding + // request inputs so that we can do the pinned->CPU copies in + // finalize after we have waited for all async copies to complete. + if (!cuda_used) { +#ifdef TRITON_ENABLE_GPU + if (buffer_ready_event_ != nullptr) { + cudaEventSynchronize(buffer_ready_event_); + buffer_ready_event_ = nullptr; + } +#endif // TRITON_ENABLE_GPU + auto err = CopyBuffer( + "pinned input buffer H2D", TRITONSERVER_MEMORY_CPU_PINNED, + 0 /* memory_type_id */, tensor_memory_type, tensor_memory_type_id, + pending_pinned_byte_size_, pinned_memory, + tensor_buffer + pending_pinned_offset_, stream_, &cuda_used, + copy_on_stream_); + cuda_copy |= cuda_used; + + // If something goes wrong with the copy all the pending + // responses fail... 
+ if (err != nullptr) { + for (auto& pr : pending_pinned_input_buffers_) { + for (size_t idx = pr.start_request_idx_; idx <= pr.end_request_idx_; + ++idx) { + if ((*responses_)[idx] != nullptr) { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + (*responses_)[idx], + TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), + "failed to send error response"); + (*responses_)[idx] = nullptr; + } + } + } + TRITONSERVER_ErrorDelete(err); + } + } else { // cuda_used + deferred_pinned_.emplace_back( + pinned_memory, pending_pinned_byte_size_, tensor_buffer, + pending_pinned_offset_, tensor_memory_type, tensor_memory_type_id, + std::move(pending_pinned_input_buffers_), responses_); + } + } else { + async_task_count_++; + deferred_pinned_.emplace_back( + pinned_memory, pending_pinned_byte_size_, tensor_buffer, + pending_pinned_offset_, tensor_memory_type, tensor_memory_type_id, + std::move(pending_pinned_input_buffers_), responses_); + auto& deferred_pinned = deferred_pinned_.back(); + // Mark finalized to avoid duplicated call to DeferredPinned::Finalized() + // in BackendInputCollector::Finalize() + deferred_pinned_.back().finalized_ = true; + auto incomplete_count = new std::atomic(std::min( + deferred_pinned_.back().requests_.size(), + triton::common::AsyncWorkQueue::WorkerCount())); + auto pending_pinned_byte_size = pending_pinned_byte_size_; + size_t stride = (deferred_pinned_.back().requests_.size() + + triton::common::AsyncWorkQueue::WorkerCount() - 1) / + triton::common::AsyncWorkQueue::WorkerCount(); + auto pending_it = deferred_pinned_.back().requests_.begin(); + while (pending_it != deferred_pinned_.back().requests_.end()) { + auto end_it = pending_it; + auto next_offset = offset; + for (size_t idx = 0; idx < stride; idx++) { + next_offset += end_it->memory_desc_.byte_size_; + end_it++; + if (end_it == deferred_pinned_.back().requests_.end()) { + break; + } + } + + auto err = + CommonErrorToTritonError(triton::common::AsyncWorkQueue::AddTask( + [this, offset, pinned_memory, pinned_memory_type, + pending_pinned_byte_size, pinned_memory_type_id, pending_it, + end_it, incomplete_count, &deferred_pinned]() mutable { + for (; pending_it != end_it; pending_it++) { + SetInputTensor( + "pinned async H2H", *pending_it, pinned_memory, + pending_pinned_byte_size, pinned_memory_type, + pinned_memory_type_id, offset, + TRITONSERVER_MEMORY_CPU_PINNED, false, false); + offset += pending_it->memory_desc_.byte_size_; + } + // The last segmented task will start the next phase of + // the internal pinned buffer copy + if (incomplete_count->fetch_sub(1) == 1) { +#ifdef TRITON_ENABLE_GPU + if (buffer_ready_event_ != nullptr) { + cudaEventSynchronize(buffer_ready_event_); + buffer_ready_event_ = nullptr; + } +#endif // TRITON_ENABLE_GPU + completion_queue_.Put(deferred_pinned.Finalize(stream_)); + delete incomplete_count; + } + })); + if (err != nullptr) { + for (; pending_it != end_it; pending_it++) { + for (size_t idx = pending_it->start_request_idx_; + idx <= pending_it->end_request_idx_; ++idx) { + if ((*responses_)[idx] != nullptr) { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + (*responses_)[idx], + TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), + "failed to send error response"); + (*responses_)[idx] = nullptr; + } + } + } + } + TRITONSERVER_ErrorDelete(err); + + offset = next_offset; + pending_it = end_it; + } + } + } + + // Pending pinned copies are handled... 
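// ---------------------------------------------------------------------------
// Editorial worked example of the partitioning above: with N pending input
// regions and W = AsyncWorkQueue::WorkerCount() workers, the host-to-host
// copies are split into min(N, W) tasks of stride = ceil(N / W) regions each.
// For N = 10 and W = 4:
//   stride = (10 + 4 - 1) / 4 = 3
//   task 0 copies regions [0, 3), task 1 copies [3, 6), task 2 copies [6, 9),
//   task 3 copies [9, 10); the task that drops 'incomplete_count' to zero then
//   issues the single pinned -> destination copy via DeferredPinned::Finalize.
// ---------------------------------------------------------------------------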
+ pending_pinned_byte_size_ = 0; + pending_pinned_offset_ = 0; + pending_pinned_input_buffers_.clear(); + + // Need to hold on to the allocated pinned buffer as there are still + // copies in flight. Will delete it in finalize. + if (pinned_memory != nullptr) { + in_use_memories_.emplace_back(backend_memory); + } + + return cuda_copy; +} + +TRITONSERVER_Error* +BackendInputCollector::BatchInputShape( + const BatchInput& batch_input, std::vector* shape) +{ + *shape = std::vector{0}; + switch (batch_input.BatchInputKind()) { + case BatchInput::Kind::BATCH_ELEMENT_COUNT: + case BatchInput::Kind::BATCH_ACCUMULATED_ELEMENT_COUNT: { + (*shape)[0] = request_count_; + break; + } + case BatchInput::Kind::BATCH_ACCUMULATED_ELEMENT_COUNT_WITH_ZERO: { + (*shape)[0] = request_count_ + 1; + break; + } + case BatchInput::Kind::BATCH_MAX_ELEMENT_COUNT_AS_SHAPE: { + const auto& source_input = batch_input.SourceInputs()[0]; + for (size_t req_idx = 0; req_idx < request_count_; req_idx++) { + TRITONBACKEND_Input* input; + RETURN_IF_ERROR(TRITONBACKEND_RequestInput( + requests_[req_idx], source_input.c_str(), &input)); + const int64_t* shape_arr; + uint32_t dims_count; + RETURN_IF_ERROR(TRITONBACKEND_InputPropertiesForHostPolicy( + input, host_policy_cstr_, nullptr, nullptr, &shape_arr, &dims_count, + nullptr, nullptr)); + (*shape)[0] = + std::max((*shape)[0], GetElementCount(shape_arr, dims_count)); + } + break; + } + case BatchInput::Kind::BATCH_ITEM_SHAPE: { + shape->emplace_back(0); + const auto& source_input = batch_input.SourceInputs()[0]; + for (size_t req_idx = 0; req_idx < request_count_; req_idx++) { + TRITONBACKEND_Input* input; + RETURN_IF_ERROR(TRITONBACKEND_RequestInput( + requests_[req_idx], source_input.c_str(), &input)); + const int64_t* shape_arr; + uint32_t dims_count; + RETURN_IF_ERROR(TRITONBACKEND_InputPropertiesForHostPolicy( + input, host_policy_cstr_, nullptr, nullptr, &shape_arr, &dims_count, + nullptr, nullptr)); + // Assuming first dimension is batch size and ragged input is only set + // for batching enabled model. + (*shape)[0] += shape_arr[0]; + // The batch input tracks the shape without batch dimension for + // each batch item + (*shape)[1] = (dims_count - 1); + } + break; + } + case BatchInput::Kind::BATCH_ITEM_SHAPE_FLATTEN: { + const auto& source_input = batch_input.SourceInputs()[0]; + for (size_t req_idx = 0; req_idx < request_count_; req_idx++) { + TRITONBACKEND_Input* input; + RETURN_IF_ERROR(TRITONBACKEND_RequestInput( + requests_[req_idx], source_input.c_str(), &input)); + const int64_t* shape_arr; + uint32_t dims_count; + RETURN_IF_ERROR(TRITONBACKEND_InputPropertiesForHostPolicy( + input, host_policy_cstr_, nullptr, nullptr, &shape_arr, &dims_count, + nullptr, nullptr)); + // Assuming first dimension is batch size and ragged input is only set + // for batching enabled model. 
+ // The batch input tracks the shape without batch dimension for + // each batch item + (*shape)[0] += (shape_arr[0] * (dims_count - 1)); + } + break; + } + default: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "unsupported BatchInputKind received"); + } + return nullptr; // success +} + +TRITONSERVER_Error* +BackendInputCollector::ProcessBatchInput( + const BatchInput& batch_input, char* buffer, const size_t buffer_byte_size, + const std::vector>& + allowed_input_types, + const char** dst_buffer, size_t* dst_buffer_byte_size, + TRITONSERVER_MemoryType* dst_memory_type, int64_t* dst_memory_type_id) +{ +#ifdef TRITON_ENABLE_GPU + if (buffer_ready_event_ != nullptr) { + cudaEventSynchronize(buffer_ready_event_); + buffer_ready_event_ = nullptr; + } +#endif // TRITON_ENABLE_GPU + if (buffer == nullptr) { + if (allowed_input_types.size() == 0) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "'allowed_input_types' must contain at least one pair of memory type " + "and id"); + } + // Calculate the byte size of the buffer + std::vector shape; + RETURN_IF_ERROR(BatchInputShape(batch_input, &shape)); + *dst_buffer_byte_size = GetByteSize(batch_input.DataType(), shape); + BackendMemory* backend_memory = nullptr; + for (const auto& allowed_type : allowed_input_types) { + std::vector alloc_types; + const int64_t memory_type_id = allowed_type.second; + switch (allowed_type.first) { + case TRITONSERVER_MEMORY_GPU: + alloc_types = {BackendMemory::AllocationType::GPU_POOL, + BackendMemory::AllocationType::GPU}; + break; + case TRITONSERVER_MEMORY_CPU_PINNED: + alloc_types = {BackendMemory::AllocationType::CPU_PINNED_POOL, + BackendMemory::AllocationType::CPU_PINNED}; + break; + case TRITONSERVER_MEMORY_CPU: + alloc_types = {BackendMemory::AllocationType::CPU}; + break; + } + auto err = BackendMemory::Create( + memory_manager_, alloc_types, memory_type_id, *dst_buffer_byte_size, + &backend_memory); + if (err != nullptr) { + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("unable to create backend memory for type: ") + + TRITONSERVER_MemoryTypeString(allowed_type.first) + + " id: " + std::to_string(memory_type_id) + ": " + + TRITONSERVER_ErrorMessage(err)) + .c_str()); + TRITONSERVER_ErrorDelete(err); + } else { + in_use_memories_.emplace_back(backend_memory); + break; + } + } + if (backend_memory == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string( + "failed to allocate contiguous buffer for batch input '") + + batch_input.TargetNames()[0] + "'") + .c_str()); + } + buffer = backend_memory->MemoryPtr(); + *dst_buffer = backend_memory->MemoryPtr(); + *dst_buffer_byte_size = backend_memory->ByteSize(); + *dst_memory_type = backend_memory->MemoryType(); + *dst_memory_type_id = backend_memory->MemoryTypeId(); + } else { + if (allowed_input_types.size() != 1) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "'allowed_input_types' must only contain the memory type and id of " + "'buffer'"); + } + *dst_buffer = buffer; + *dst_buffer_byte_size = buffer_byte_size; + *dst_memory_type = allowed_input_types[0].first; + *dst_memory_type_id = allowed_input_types[0].second; + } + + char* input_buffer = buffer; + std::unique_ptr internal_buffer; + // Need a CPU buffer for modifying the value + if (*dst_memory_type == TRITONSERVER_MEMORY_GPU) { + BackendMemory* ib = nullptr; + RETURN_IF_ERROR(BackendMemory::Create( + memory_manager_, + {BackendMemory::AllocationType::CPU_PINNED_POOL, + BackendMemory::AllocationType::CPU}, 
+ 0, *dst_buffer_byte_size, &ib)); + internal_buffer.reset(ib); + input_buffer = internal_buffer->MemoryPtr(); + } + const auto& data_type = batch_input.DataType(); + switch (batch_input.BatchInputKind()) { + case BatchInput::Kind::BATCH_ELEMENT_COUNT: { + const auto& source_input = batch_input.SourceInputs()[0]; + if (data_type == TRITONSERVER_TYPE_FP32) { + RETURN_IF_ERROR(SetElementCount( + source_input, input_buffer, *dst_buffer_byte_size)); + } else { + RETURN_IF_ERROR(SetElementCount( + source_input, input_buffer, *dst_buffer_byte_size)); + } + break; + } + case BatchInput::Kind::BATCH_ACCUMULATED_ELEMENT_COUNT: { + const auto& source_input = batch_input.SourceInputs()[0]; + if (data_type == TRITONSERVER_TYPE_FP32) { + RETURN_IF_ERROR(SetAccumulatedElementCount( + source_input, input_buffer, *dst_buffer_byte_size)); + } else { + RETURN_IF_ERROR(SetAccumulatedElementCount( + source_input, input_buffer, *dst_buffer_byte_size)); + } + break; + } + case BatchInput::Kind::BATCH_ACCUMULATED_ELEMENT_COUNT_WITH_ZERO: { + const auto& source_input = batch_input.SourceInputs()[0]; + if (data_type == TRITONSERVER_TYPE_FP32) { + *reinterpret_cast(input_buffer) = 0; + RETURN_IF_ERROR(SetAccumulatedElementCount( + source_input, input_buffer + sizeof(float), + *dst_buffer_byte_size - sizeof(float))); + } else { + *reinterpret_cast(input_buffer) = 0; + RETURN_IF_ERROR(SetAccumulatedElementCount( + source_input, input_buffer + sizeof(int32_t), + *dst_buffer_byte_size - sizeof(int32_t))); + } + break; + } + case BatchInput::Kind::BATCH_MAX_ELEMENT_COUNT_AS_SHAPE: { + // The batch input is described by the shape, + // no data modification is needed + return nullptr; // success + } + case BatchInput::Kind::BATCH_ITEM_SHAPE: + case BatchInput::Kind::BATCH_ITEM_SHAPE_FLATTEN: { + // Use the same utilities for both types as the data will be the same, + // only difference is the shape of the tensor. 
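// ---------------------------------------------------------------------------
// Editorial worked example for the batch-input kinds handled above and below.
// Suppose three requests whose source input has 4, 6 and 2 elements:
//   BATCH_ELEMENT_COUNT                        -> [4, 6, 2]
//   BATCH_ACCUMULATED_ELEMENT_COUNT            -> [4, 10, 12]
//   BATCH_ACCUMULATED_ELEMENT_COUNT_WITH_ZERO  -> [0, 4, 10, 12]
//   BATCH_MAX_ELEMENT_COUNT_AS_SHAPE           -> no data written; the batch
//                                                 input's shape becomes [6]
//   BATCH_ITEM_SHAPE / _FLATTEN                -> each batch item's shape
//                                                 without the batch dimension
//                                                 (see SetBatchItemShape below)
// Values are written as float or int32 depending on the batch input's
// configured data type.
// ---------------------------------------------------------------------------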
+ const auto& source_input = batch_input.SourceInputs()[0]; + if (data_type == TRITONSERVER_TYPE_FP32) { + *reinterpret_cast(input_buffer) = 0; + RETURN_IF_ERROR(SetBatchItemShape( + source_input, input_buffer, *dst_buffer_byte_size)); + } else { + *reinterpret_cast(input_buffer) = 0; + RETURN_IF_ERROR(SetBatchItemShape( + source_input, input_buffer, *dst_buffer_byte_size)); + } + break; + } + } + if (*dst_memory_type == TRITONSERVER_MEMORY_GPU) { + bool cuda_used; + RETURN_IF_ERROR(CopyBuffer( + "batch input buffer", internal_buffer->MemoryType(), + internal_buffer->MemoryTypeId(), *dst_memory_type, *dst_memory_type_id, + *dst_buffer_byte_size, input_buffer, buffer, stream_, &cuda_used, + copy_on_stream_)); + // Need to keep the backend memory alive in the case of async copy + in_use_memories_.emplace_back(std::move(internal_buffer)); + need_sync_ |= cuda_used; + } + return nullptr; // success +} + +template +TRITONSERVER_Error* +BackendInputCollector::SetElementCount( + const std::string& source_input, char* buffer, + const size_t buffer_byte_size) +{ + size_t buffer_offset = 0; + for (size_t req_idx = 0; req_idx < request_count_; req_idx++) { + if (buffer_offset + sizeof(T) > buffer_byte_size) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "unexpected total byte size for batch input"); + } + + TRITONBACKEND_Input* input; + RETURN_IF_ERROR(TRITONBACKEND_RequestInput( + requests_[req_idx], source_input.c_str(), &input)); + const int64_t* shape; + uint32_t dims_count; + RETURN_IF_ERROR(TRITONBACKEND_InputPropertiesForHostPolicy( + input, host_policy_cstr_, nullptr, nullptr, &shape, &dims_count, + nullptr, nullptr)); + *(reinterpret_cast(buffer) + req_idx) = + GetElementCount(shape, dims_count); + buffer_offset += sizeof(T); + } + // Set the rest of the buffer to 0 + for (; buffer_offset + sizeof(T) <= buffer_byte_size; + buffer_offset += sizeof(T)) { + *reinterpret_cast(buffer + buffer_offset) = 0; + } + return nullptr; // success +} + +template +TRITONSERVER_Error* +BackendInputCollector::SetAccumulatedElementCount( + const std::string& source_input, char* buffer, + const size_t buffer_byte_size) +{ + size_t accumulated_element_count = 0; + size_t buffer_offset = 0; + for (size_t req_idx = 0; req_idx < request_count_; req_idx++) { + if (buffer_offset + sizeof(T) > buffer_byte_size) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "unexpected total byte size for batch input"); + } + + TRITONBACKEND_Input* input; + RETURN_IF_ERROR(TRITONBACKEND_RequestInput( + requests_[req_idx], source_input.c_str(), &input)); + const int64_t* shape; + uint32_t dims_count; + RETURN_IF_ERROR(TRITONBACKEND_InputPropertiesForHostPolicy( + input, host_policy_cstr_, nullptr, nullptr, &shape, &dims_count, + nullptr, nullptr)); + accumulated_element_count += GetElementCount(shape, dims_count); + *(reinterpret_cast(buffer) + req_idx) = accumulated_element_count; + buffer_offset += sizeof(T); + } + // Set the rest of the buffer to 'accumulated_element_count' + // (no increase in element count) + for (; buffer_offset + sizeof(T) <= buffer_byte_size; + buffer_offset += sizeof(T)) { + *reinterpret_cast(buffer + buffer_offset) = accumulated_element_count; + } + return nullptr; // success +} + +template +TRITONSERVER_Error* +BackendInputCollector::SetBatchItemShape( + const std::string& source_input, char* buffer, + const size_t buffer_byte_size) +{ + size_t buffer_offset = 0; + for (size_t req_idx = 0; req_idx < request_count_; req_idx++) { + TRITONBACKEND_Input* input; + 
RETURN_IF_ERROR(TRITONBACKEND_RequestInput( + requests_[req_idx], source_input.c_str(), &input)); + const int64_t* shape; + uint32_t dims_count; + RETURN_IF_ERROR(TRITONBACKEND_InputPropertiesForHostPolicy( + input, host_policy_cstr_, nullptr, nullptr, &shape, &dims_count, + nullptr, nullptr)); + // Assuming first dimension is batch size and ragged input is only set + // for batching enabled model. + size_t batch_1_size = sizeof(T) * (dims_count - 1); + if (buffer_offset + (size_t)shape[0] * batch_1_size > buffer_byte_size) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (GetRequestId(requests_[req_idx]) + + "unexpected total byte size for batch input") + .c_str()); + } + // The batch input tracks the shape without batch dimension for + // each batch item + for (size_t idx = 1; idx < dims_count; ++idx) { + // Need to set the element explicitly for type conversion + *(reinterpret_cast(buffer + buffer_offset) + (idx - 1)) = shape[idx]; + } + // memcpy the data repeatedly if the request has batch size > 1 + for (int64_t idx = 1; idx < shape[0]; ++idx) { + memcpy( + buffer + buffer_offset + idx * batch_1_size, buffer + buffer_offset, + batch_1_size); + } + buffer_offset += batch_1_size * (size_t)shape[0]; + } + return nullptr; // success +} + +bool +BackendInputCollector::FlushPendingCopyKernel( + char* tensor_buffer, const size_t tensor_buffer_byte_size, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id) +{ + if (pending_copy_kernel_input_buffers_.size() == 0) { + return false; + } + + bool cuda_copy = false; + TRITONSERVER_Error* error = nullptr; + // Only try to launch kernel if buffer count is large enough for + // good GPU utilization + if (pending_copy_kernel_input_buffer_counts_ >= kernel_buffer_threshold_) { + error = LaunchCopyKernel( + tensor_buffer, tensor_buffer_byte_size, tensor_memory_type, + tensor_memory_type_id); + cuda_copy = (error == nullptr); + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("gather kernel launched with status: ") + + ((error == nullptr) ? "Success" : TRITONSERVER_ErrorMessage(error))) + .c_str()); + } + // If kernel can't be launched then just perform a direct copy. + if ((pending_copy_kernel_input_buffer_counts_ < kernel_buffer_threshold_) || + (error != nullptr)) { + size_t offset = 0; + for (auto& pr : pending_copy_kernel_input_buffers_) { + cuda_copy |= SetInputTensor( + "gather kernel fallback", pr, tensor_buffer, tensor_buffer_byte_size, + tensor_memory_type, tensor_memory_type_id, + pending_copy_kernel_buffer_offset_ + offset, + TRITONSERVER_MEMORY_CPU_PINNED, false, true); + offset += pr.memory_desc_.byte_size_; + } + } + TRITONSERVER_ErrorDelete(error); + + // Pending kernel copies are handled... 
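// ---------------------------------------------------------------------------
// Editorial note on the threshold logic above: with kernel_buffer_threshold_
// == 8, 20 pending source buffers are gathered by a single RunGatherKernel
// launch, while 5 pending buffers fall back to 5 individual CopyBuffer calls
// (the SetInputTensor "gather kernel fallback" path). The same fallback is
// taken when the kernel launch itself reports an error.
// ---------------------------------------------------------------------------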
+ pending_copy_kernel_buffer_byte_size_ = 0; + pending_copy_kernel_buffer_offset_ = 0; + pending_copy_kernel_input_buffer_counts_ = 0; + pending_copy_kernel_input_buffers_.clear(); + + return cuda_copy; +} + +TRITONSERVER_Error* +BackendInputCollector::LaunchCopyKernel( + char* tensor_buffer, const size_t tensor_buffer_byte_size, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id) +{ +#ifdef TRITON_ENABLE_GPU + input_ptr_buffer_host_.emplace_back(new std::vector()); + byte_size_buffer_host_.emplace_back(new std::vector()); + byte_size_offset_buffer_host_.emplace_back(new std::vector()); + + auto& input_ptr_buffer_host = *input_ptr_buffer_host_.back(); + auto& byte_size_buffer_host = *byte_size_buffer_host_.back(); + auto& byte_size_offset_buffer_host = *byte_size_offset_buffer_host_.back(); + + input_ptr_buffer_host.reserve(pending_copy_kernel_input_buffer_counts_); + byte_size_buffer_host.reserve(pending_copy_kernel_input_buffer_counts_); + byte_size_offset_buffer_host.reserve( + pending_copy_kernel_input_buffer_counts_); + + size_t byte_size_offset = 0; + for (const auto& response_input : pending_copy_kernel_input_buffers_) { + const auto& input = response_input.memory_desc_; + input_ptr_buffer_host.emplace_back( + const_cast(reinterpret_cast(input.buffer_))); + byte_size_buffer_host.emplace_back(input.byte_size_); + byte_size_offset_buffer_host.emplace_back(byte_size_offset); + byte_size_offset += input.byte_size_; + } + + BackendMemory* backend_memory = nullptr; + std::vector alloc_types; + switch (tensor_memory_type) { + case TRITONSERVER_MEMORY_GPU: + alloc_types = {BackendMemory::AllocationType::GPU_POOL, + BackendMemory::AllocationType::GPU}; + break; + case TRITONSERVER_MEMORY_CPU_PINNED: + alloc_types = {BackendMemory::AllocationType::CPU_PINNED_POOL, + BackendMemory::AllocationType::CPU_PINNED}; + break; + case TRITONSERVER_MEMORY_CPU: + alloc_types = {BackendMemory::AllocationType::CPU}; + break; + } + + // input_ptr_buffer + size_t input_ptr_buffer_byte_size = + pending_copy_kernel_input_buffer_counts_ * sizeof(int8_t*); + auto err = BackendMemory::Create( + memory_manager_, alloc_types, tensor_memory_type_id, + input_ptr_buffer_byte_size, &backend_memory); + if (err != nullptr) { + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("unable to create backend memory for type: ") + + TRITONSERVER_MemoryTypeString(tensor_memory_type) + + " id: " + std::to_string(tensor_memory_type_id) + ": " + + TRITONSERVER_ErrorMessage(err)) + .c_str()); + TRITONSERVER_ErrorDelete(err); + } else { + in_use_memories_.emplace_back(backend_memory); + } + if (backend_memory == nullptr || + (backend_memory->MemoryType() != tensor_memory_type) || + (backend_memory->MemoryTypeId() != tensor_memory_type_id)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "Failed to obtain memory buffer for copy kernel input"); + } + char* input_ptr_buffer = backend_memory->MemoryPtr(); + + // byte_size_buffer + size_t byte_size_buffer_byte_size = + pending_copy_kernel_input_buffer_counts_ * sizeof(size_t); + err = BackendMemory::Create( + memory_manager_, alloc_types, tensor_memory_type_id, + byte_size_buffer_byte_size, &backend_memory); + if (err != nullptr) { + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("unable to create backend memory for type: ") + + TRITONSERVER_MemoryTypeString(tensor_memory_type) + + " id: " + std::to_string(tensor_memory_type_id) + ": " + + TRITONSERVER_ErrorMessage(err)) + .c_str()); + TRITONSERVER_ErrorDelete(err); + 
} else { + in_use_memories_.emplace_back(backend_memory); + } + if (backend_memory == nullptr || + (backend_memory->MemoryType() != tensor_memory_type) || + (backend_memory->MemoryTypeId() != tensor_memory_type_id)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "Failed to obtain memory buffer for copy kernel input"); + } + char* byte_size_buffer = backend_memory->MemoryPtr(); + + // byte_size_offset_buffer + size_t byte_size_offset_buffer_byte_size = + pending_copy_kernel_input_buffer_counts_ * sizeof(size_t); + err = BackendMemory::Create( + memory_manager_, alloc_types, tensor_memory_type_id, + byte_size_offset_buffer_byte_size, &backend_memory); + if (err != nullptr) { + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("unable to create backend memory for type: ") + + TRITONSERVER_MemoryTypeString(tensor_memory_type) + + " id: " + std::to_string(tensor_memory_type_id) + ": " + + TRITONSERVER_ErrorMessage(err)) + .c_str()); + TRITONSERVER_ErrorDelete(err); + } else { + in_use_memories_.emplace_back(backend_memory); + } + if (backend_memory == nullptr || + (backend_memory->MemoryType() != tensor_memory_type) || + (backend_memory->MemoryTypeId() != tensor_memory_type_id)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "Failed to obtain memory buffer for copy kernel input"); + } + char* byte_size_offset_buffer = backend_memory->MemoryPtr(); + + cudaMemcpyAsync( + input_ptr_buffer, input_ptr_buffer_host.data(), + pending_copy_kernel_input_buffer_counts_ * sizeof(int8_t*), + cudaMemcpyDefault, stream_); + cudaMemcpyAsync( + byte_size_buffer, byte_size_buffer_host.data(), + pending_copy_kernel_input_buffer_counts_ * sizeof(size_t), + cudaMemcpyDefault, stream_); + cudaMemcpyAsync( + byte_size_offset_buffer, byte_size_offset_buffer_host.data(), + pending_copy_kernel_input_buffer_counts_ * sizeof(size_t), + cudaMemcpyDefault, stream_); + if (buffer_ready_event_ != nullptr) { + cudaEventSynchronize(buffer_ready_event_); + buffer_ready_event_ = nullptr; + } + RETURN_IF_CUDA_ERROR( + RunGatherKernel( + (const int8_t**)input_ptr_buffer, (const size_t*)byte_size_buffer, + (const size_t*)byte_size_offset_buffer, + (int8_t*)tensor_buffer + pending_copy_kernel_buffer_offset_, + pending_copy_kernel_input_buffer_counts_, stream_), + TRITONSERVER_ERROR_INTERNAL, + std::string("Failed to launch gather kernel")); + return nullptr; +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "Copy kernel can not be launched with TRITON_ENABLE_GPU=OFF"); +#endif // TRITON_ENABLE_GPU +} + +}} // namespace triton::backend diff --git a/3rdparty/backend-r22.12/src/backend_memory.cc b/3rdparty/backend-r22.12/src/backend_memory.cc new file mode 100644 index 0000000000000000000000000000000000000000..9dd1594552de171d2d51a44a38b1467f1b30cb89 --- /dev/null +++ b/3rdparty/backend-r22.12/src/backend_memory.cc @@ -0,0 +1,231 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. 
+// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "triton/backend/backend_memory.h" + +#include +#include "triton/backend/backend_common.h" + +namespace triton { namespace backend { + +TRITONSERVER_Error* +BackendMemory::Create( + TRITONBACKEND_MemoryManager* manager, const AllocationType alloc_type, + const int64_t memory_type_id, const size_t byte_size, BackendMemory** mem) +{ + *mem = nullptr; + + void* ptr = nullptr; + switch (alloc_type) { + case AllocationType::CPU_PINNED: { +#ifdef TRITON_ENABLE_GPU + RETURN_IF_CUDA_ERROR( + cudaHostAlloc(&ptr, byte_size, cudaHostAllocPortable), + TRITONSERVER_ERROR_UNAVAILABLE, + std::string("failed to allocate pinned system memory")); +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "pinned-memory allocation not supported"); +#endif // TRITON_ENABLE_GPU + break; + } + + case AllocationType::GPU: { +#ifdef TRITON_ENABLE_GPU + int current_device; + RETURN_IF_CUDA_ERROR( + cudaGetDevice(¤t_device), TRITONSERVER_ERROR_INTERNAL, + std::string("failed to get device")); + bool overridden = (current_device != memory_type_id); + if (overridden) { + RETURN_IF_CUDA_ERROR( + cudaSetDevice(memory_type_id), TRITONSERVER_ERROR_INTERNAL, + std::string("failed to set device")); + } + + auto err = cudaMalloc(&ptr, byte_size); + + if (overridden) { + LOG_IF_CUDA_ERROR( + cudaSetDevice(current_device), "failed to set CUDA device"); + } + + RETURN_ERROR_IF_FALSE( + err == cudaSuccess, TRITONSERVER_ERROR_UNAVAILABLE, + std::string("failed to allocate GPU memory: ") + + cudaGetErrorString(err)); +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "GPU allocation not supported"); +#endif // TRITON_ENABLE_GPU + break; + } + + case AllocationType::CPU: + case AllocationType::CPU_PINNED_POOL: + case AllocationType::GPU_POOL: + RETURN_IF_ERROR(TRITONBACKEND_MemoryManagerAllocate( + manager, &ptr, AllocTypeToMemoryType(alloc_type), memory_type_id, + byte_size)); + break; + } + + *mem = new BackendMemory( + manager, alloc_type, memory_type_id, reinterpret_cast(ptr), + byte_size); + + return nullptr; // success +} + +TRITONSERVER_Error* +BackendMemory::Create( + TRITONBACKEND_MemoryManager* manager, + const std::vector& alloc_types, + const int64_t memory_type_id, const size_t byte_size, BackendMemory** mem) +{ + *mem = nullptr; + RETURN_ERROR_IF_TRUE( + alloc_types.size() == 0, TRITONSERVER_ERROR_INVALID_ARG, + std::string("BackendMemory::Create, at least one allocation type must be " + "specified")); + + bool success = false; + std::unordered_map errors; + for (const AllocationType alloc_type : 
alloc_types) { + TRITONSERVER_Error* err = + Create(manager, alloc_type, memory_type_id, byte_size, mem); + if (err == nullptr) { + success = true; + break; + } + + errors.insert({alloc_type, err}); + } + + // If allocation failed for all allocation types then display all + // the error messages and show the entire allocation request as + // failing. + if (!success) { + std::string msg = "BackendMemory::Create, all allocation types failed:"; + for (const auto& pr : errors) { + const AllocationType alloc_type = pr.first; + TRITONSERVER_Error* err = pr.second; + msg += std::string("\n\t") + AllocTypeString(alloc_type) + ": " + + TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } + + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNAVAILABLE, msg.c_str()); + } + + return nullptr; // success +} + +TRITONSERVER_Error* +BackendMemory::Create( + TRITONBACKEND_MemoryManager* manager, const AllocationType alloc_type, + const int64_t memory_type_id, void* buffer, const size_t byte_size, + BackendMemory** mem) +{ + *mem = new BackendMemory( + manager, alloc_type, memory_type_id, reinterpret_cast(buffer), + byte_size, false /* owns_buffer */); + + return nullptr; // success +} + +BackendMemory::~BackendMemory() +{ + if (owns_buffer_) { + switch (alloctype_) { + case AllocationType::CPU_PINNED: +#ifdef TRITON_ENABLE_GPU + if (buffer_ != nullptr) { + LOG_IF_CUDA_ERROR( + cudaFreeHost(buffer_), "failed to free pinned memory"); + } +#endif // TRITON_ENABLE_GPU + break; + + case AllocationType::GPU: +#ifdef TRITON_ENABLE_GPU + if (buffer_ != nullptr) { + LOG_IF_CUDA_ERROR(cudaFree(buffer_), "failed to free CUDA memory"); + } +#endif // TRITON_ENABLE_GPU + break; + + case AllocationType::CPU: + case AllocationType::CPU_PINNED_POOL: + case AllocationType::GPU_POOL: + LOG_IF_ERROR( + TRITONBACKEND_MemoryManagerFree( + manager_, buffer_, AllocTypeToMemoryType(alloctype_), + memtype_id_), + "failed to free memory buffer"); + break; + } + } +} + +TRITONSERVER_MemoryType +BackendMemory::AllocTypeToMemoryType(const AllocationType a) +{ + switch (a) { + case AllocationType::CPU: + return TRITONSERVER_MEMORY_CPU; + case AllocationType::CPU_PINNED: + case AllocationType::CPU_PINNED_POOL: + return TRITONSERVER_MEMORY_CPU_PINNED; + case AllocationType::GPU: + case AllocationType::GPU_POOL: + return TRITONSERVER_MEMORY_GPU; + } + + return TRITONSERVER_MEMORY_CPU; // unreachable +} + +const char* +BackendMemory::AllocTypeString(const AllocationType a) +{ + switch (a) { + case AllocationType::CPU: + return "CPU"; + case AllocationType::CPU_PINNED: + return "CPU_PINNED"; + case AllocationType::GPU: + return "GPU"; + case AllocationType::CPU_PINNED_POOL: + return "CPU_PINNED_POOL"; + case AllocationType::GPU_POOL: + return "GPU_POOL"; + } + + return ""; +} + +}} // namespace triton::backend diff --git a/3rdparty/backend-r22.12/src/backend_model.cc b/3rdparty/backend-r22.12/src/backend_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..1859580d3eb527c8cbd39e0418c1e0999508f67d --- /dev/null +++ b/3rdparty/backend-r22.12/src/backend_model.cc @@ -0,0 +1,192 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "triton/backend/backend_model.h" + +#include "triton/backend/backend_common.h" + +namespace triton { namespace backend { + +// +// BackendModel +// +BackendModel::BackendModel( + TRITONBACKEND_Model* triton_model, const bool allow_optional) + : triton_model_(triton_model), allow_optional_(allow_optional) +{ + const char* model_name; + THROW_IF_BACKEND_MODEL_ERROR( + TRITONBACKEND_ModelName(triton_model, &model_name)); + name_ = model_name; + + THROW_IF_BACKEND_MODEL_ERROR( + TRITONBACKEND_ModelVersion(triton_model, &version_)); + + const char* repository_path = nullptr; + TRITONBACKEND_ArtifactType repository_artifact_type; + THROW_IF_BACKEND_MODEL_ERROR(TRITONBACKEND_ModelRepository( + triton_model, &repository_artifact_type, &repository_path)); + if (repository_artifact_type != TRITONBACKEND_ARTIFACT_FILESYSTEM) { + throw BackendModelException(TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + (std::string("unsupported repository artifact type for model '") + + model_name + "'") + .c_str())); + } + repository_path_ = repository_path; + + THROW_IF_BACKEND_MODEL_ERROR( + TRITONBACKEND_ModelServer(triton_model, &triton_server_)); + TRITONBACKEND_Backend* backend; + THROW_IF_BACKEND_MODEL_ERROR( + TRITONBACKEND_ModelBackend(triton_model, &backend)); + THROW_IF_BACKEND_MODEL_ERROR( + TRITONBACKEND_BackendMemoryManager(backend, &triton_memory_manager_)); + + THROW_IF_BACKEND_MODEL_ERROR(ParseModelConfig()); +} + +TRITONSERVER_Error* +BackendModel::ParseModelConfig() +{ + TRITONSERVER_Message* config_message; + RETURN_IF_ERROR(TRITONBACKEND_ModelConfig( + triton_model_, 1 /* config_version */, &config_message)); + + // Get the model configuration as a json string from + // config_message. We use TritonJson, which is a wrapper that + // returns nice errors (currently the underlying implementation is + // rapidjson... but others could be added). 
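// ---------------------------------------------------------------------------
// Editorial sketch of the TritonJson calls used below, applied to a
// stand-alone JSON string. Error checking is omitted for brevity; in real
// code every call returns a TRITONSERVER_Error* that should be checked.
//
//   triton::common::TritonJson::Value cfg;
//   const std::string json =
//       R"({"max_batch_size": 8,
//           "optimization": {"input_pinned_memory": {"enable": true}}})";
//   cfg.Parse(json.c_str(), json.size());
//   int64_t mbs = 0;
//   cfg.MemberAsInt("max_batch_size", &mbs);          // mbs == 8
//   triton::common::TritonJson::Value opt, pinned;
//   if (cfg.Find("optimization", &opt) &&
//       opt.Find("input_pinned_memory", &pinned)) {
//     bool enable = false;
//     pinned.MemberAsBool("enable", &enable);         // enable == true
//   }
// ---------------------------------------------------------------------------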
+ const char* buffer; + size_t byte_size; + RETURN_IF_ERROR( + TRITONSERVER_MessageSerializeToJson(config_message, &buffer, &byte_size)); + + TRITONSERVER_Error* err = model_config_.Parse(buffer, byte_size); + RETURN_IF_ERROR(TRITONSERVER_MessageDelete(config_message)); + RETURN_IF_ERROR(err); + + int64_t mbs = 0; + RETURN_IF_ERROR(model_config_.MemberAsInt("max_batch_size", &mbs)); + max_batch_size_ = mbs; + + enable_pinned_input_ = false; + enable_pinned_output_ = false; + { + common::TritonJson::Value optimization; + if (model_config_.Find("optimization", &optimization)) { + common::TritonJson::Value pinned_memory; + if (optimization.Find("input_pinned_memory", &pinned_memory)) { + RETURN_IF_ERROR( + pinned_memory.MemberAsBool("enable", &enable_pinned_input_)); + } + if (optimization.Find("output_pinned_memory", &pinned_memory)) { + RETURN_IF_ERROR( + pinned_memory.MemberAsBool("enable", &enable_pinned_output_)); + } + } + } + + RETURN_IF_ERROR( + BatchInput::ParseFromModelConfig(model_config_, &batch_inputs_)); + RETURN_IF_ERROR( + BatchOutput::ParseFromModelConfig(model_config_, &batch_outputs_)); + for (const auto& batch_output : batch_outputs_) { + for (const auto& name : batch_output.TargetNames()) { + batch_output_map_.emplace(name, &batch_output); + } + } + triton::common::TritonJson::Value config_inputs; + RETURN_IF_ERROR(model_config_.MemberAsArray("input", &config_inputs)); + for (size_t i = 0; i < config_inputs.ArraySize(); i++) { + triton::common::TritonJson::Value io; + RETURN_IF_ERROR(config_inputs.IndexAsObject(i, &io)); + std::string io_name; + RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); + triton::common::TritonJson::Value input_property_json; + bool allow_ragged_batch = false; + if (io.Find("allow_ragged_batch", &input_property_json)) { + RETURN_IF_ERROR(input_property_json.AsBool(&allow_ragged_batch)); + } + if (allow_ragged_batch) { + ragged_inputs_.emplace(io_name); + } + bool optional = false; + if (io.Find("optional", &input_property_json)) { + RETURN_IF_ERROR(input_property_json.AsBool(&optional)); + } + if (optional) { + if (allow_optional_) { + optional_inputs_.emplace(io_name); + } else { + RETURN_IF_ERROR(TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("'optional' is set to true for input '") + io_name + + "' while the backend model doesn't support optional input") + .c_str())); + } + } + } + + return nullptr; +} + +TRITONSERVER_Error* +BackendModel::SetModelConfig() +{ + triton::common::TritonJson::WriteBuffer json_buffer; + RETURN_IF_ERROR(ModelConfig().Write(&json_buffer)); + + TRITONSERVER_Message* message; + RETURN_IF_ERROR(TRITONSERVER_MessageNewFromSerializedJson( + &message, json_buffer.Base(), json_buffer.Size())); + RETURN_IF_ERROR(TRITONBACKEND_ModelSetConfig( + triton_model_, 1 /* config_version */, message)); + RETURN_IF_ERROR(TRITONSERVER_MessageDelete(message)); + + // Triton core can normalize the missing config settings + // in the above call. We must retrieve the updated model + // configration from the core. + RETURN_IF_ERROR(ParseModelConfig()); + + return nullptr; +} + +TRITONSERVER_Error* +BackendModel::SupportsFirstDimBatching(bool* supports) +{ + *supports = max_batch_size_ > 0; + return nullptr; +} + +const BatchOutput* +BackendModel::FindBatchOutput(const std::string& output_name) const +{ + const auto it = batch_output_map_.find(output_name); + return ((it == batch_output_map_.end()) ? 
nullptr : it->second); +} + +}} // namespace triton::backend diff --git a/3rdparty/backend-r22.12/src/backend_model_instance.cc b/3rdparty/backend-r22.12/src/backend_model_instance.cc new file mode 100644 index 0000000000000000000000000000000000000000..ae7ff9d71f9d3cecbca73c91b8cfe524e4871ba5 --- /dev/null +++ b/3rdparty/backend-r22.12/src/backend_model_instance.cc @@ -0,0 +1,171 @@ +// Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "triton/backend/backend_model_instance.h" + +#include +#include "triton/backend/backend_common.h" +#include "triton/backend/backend_model.h" + +namespace triton { namespace backend { + +// +// BackendModelInstance +// +BackendModelInstance::BackendModelInstance( + BackendModel* backend_model, + TRITONBACKEND_ModelInstance* triton_model_instance) + : backend_model_(backend_model), + triton_model_instance_(triton_model_instance) +{ + const char* instance_name; + THROW_IF_BACKEND_INSTANCE_ERROR( + TRITONBACKEND_ModelInstanceName(triton_model_instance, &instance_name)); + name_ = instance_name; + + THROW_IF_BACKEND_INSTANCE_ERROR( + TRITONBACKEND_ModelInstanceKind(triton_model_instance, &kind_)); + + THROW_IF_BACKEND_INSTANCE_ERROR( + TRITONBACKEND_ModelInstanceDeviceId(triton_model_instance, &device_id_)); + + common::TritonJson::Value& model_config = backend_model->ModelConfig(); + + // If the model configuration specifies a 'default_model_filename' + // and/or specifies 'cc_model_filenames' then determine the + // appropriate 'artifact_filename' value. If model configuration + // does not specify then just leave 'artifact_filename' empty and + // the backend can then provide its own logic for determine the + // filename if that is appropriate. 
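+  //
+  // For illustration only (config.pbtxt snippet; the file names are made up,
+  // not from this repo), the two settings referenced here look like:
+  //
+  //   default_model_filename: "model.plan"
+  //   cc_model_filenames { key: "7.5" value: "model_sm75.plan" }
+  //   cc_model_filenames { key: "8.0" value: "model_sm80.plan" }
+  //
+  // On a GPU instance the device compute capability ("major.minor") is
+  // looked up in 'cc_model_filenames'; otherwise, or when no entry matches,
+  // 'default_model_filename' is used.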
+ THROW_IF_BACKEND_INSTANCE_ERROR(model_config.MemberAsString( + "default_model_filename", &artifact_filename_)); + + switch (kind_) { + case TRITONSERVER_INSTANCEGROUPKIND_CPU: { + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("Creating instance ") + name_ + + " on CPU using artifact '" + artifact_filename_ + "'") + .c_str()); + break; + } + case TRITONSERVER_INSTANCEGROUPKIND_MODEL: { + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("Creating instance ") + name_ + + " on model-specified devices using artifact '" + artifact_filename_ + + "'") + .c_str()); + break; + } + case TRITONSERVER_INSTANCEGROUPKIND_GPU: { +#if defined(TRITON_ENABLE_GPU) + cudaDeviceProp cuprops; + cudaError_t cuerr = cudaGetDeviceProperties(&cuprops, device_id_); + if (cuerr != cudaSuccess) { + throw BackendModelInstanceException(TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("unable to get CUDA device properties for ") + name_ + + ": " + cudaGetErrorString(cuerr)) + .c_str())); + } + + const std::string cc = + std::to_string(cuprops.major) + "." + std::to_string(cuprops.minor); + common::TritonJson::Value cc_names; + common::TritonJson::Value cc_name; + if ((model_config.Find("cc_model_filenames", &cc_names)) && + (cc_names.Find(cc.c_str(), &cc_name))) { + cc_name.AsString(&artifact_filename_); + } + + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("Creating instance ") + name_ + " on GPU " + + std::to_string(device_id_) + " (" + cc + ") using artifact '" + + artifact_filename_ + "'") + .c_str()); +#elif !defined(TRITON_ENABLE_MALI_GPU) + throw BackendModelInstanceException(TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "GPU instances not supported")); +#endif // TRITON_ENABLE_GPU + break; + } + default: { + throw BackendModelInstanceException(TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("unexpected instance kind for ") + name_).c_str())); + } + } + + stream_ = nullptr; + if (kind_ == TRITONSERVER_INSTANCEGROUPKIND_GPU) { + THROW_IF_BACKEND_INSTANCE_ERROR( + CreateCudaStream(device_id_, 0 /* cuda_stream_priority */, &stream_)); + } + + // Get the host policy setting as a json string from message, + // and extract the host policy name for the instance. 
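+  //
+  // For illustration only (policy name and settings are made up): the
+  // serialized host policy is a JSON object with a single member whose key
+  // is the policy name, e.g.
+  //
+  //   {"gpus_0": {"numa-node": "0", "cpu-cores": "0-15"}}
+  //
+  // The key becomes 'host_policy_name_'; the nested settings are not
+  // interpreted here.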
+  TRITONSERVER_Message* message = nullptr;
+  THROW_IF_BACKEND_MODEL_ERROR(
+      TRITONBACKEND_ModelInstanceHostPolicy(triton_model_instance_, &message));
+  const char* buffer;
+  size_t byte_size;
+  THROW_IF_BACKEND_MODEL_ERROR(
+      TRITONSERVER_MessageSerializeToJson(message, &buffer, &byte_size));
+
+  common::TritonJson::Value host_policy;
+  TRITONSERVER_Error* err = host_policy.Parse(buffer, byte_size);
+  THROW_IF_BACKEND_MODEL_ERROR(err);
+  std::vector<std::string> host_policy_name;
+  THROW_IF_BACKEND_MODEL_ERROR(host_policy.Members(&host_policy_name));
+  if (host_policy_name.size() != 1) {
+    throw BackendModelInstanceException(TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INTERNAL,
+        (std::string("expected exactly one host policy for ") + name_)
+            .c_str()));
+  }
+  host_policy_name_ = host_policy_name[0];
+}
+
+
+BackendModelInstance::~BackendModelInstance()
+{
+#ifdef TRITON_ENABLE_GPU
+  if (stream_ != nullptr) {
+    cudaError_t err = cudaStreamDestroy(stream_);
+    if (err != cudaSuccess) {
+      TRITONSERVER_LogMessage(
+          TRITONSERVER_LOG_ERROR, __FILE__, __LINE__,
+          (std::string("~BackendModelInstance: ") + name_ +
+           " failed to destroy cuda stream: " + cudaGetErrorString(err))
+              .c_str());
+    }
+    stream_ = nullptr;
+  }
+#endif  // TRITON_ENABLE_GPU
+}
+
+}}  // namespace triton::backend
diff --git a/3rdparty/backend-r22.12/src/backend_output_responder.cc b/3rdparty/backend-r22.12/src/backend_output_responder.cc
new file mode 100644
index 0000000000000000000000000000000000000000..81acd2517d1080850a39296f38d3e3bc435a5b0d
--- /dev/null
+++ b/3rdparty/backend-r22.12/src/backend_output_responder.cc
@@ -0,0 +1,607 @@
+// Copyright 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//   * Redistributions of source code must retain the above copyright
+//     notice, this list of conditions and the following disclaimer.
+//   * Redistributions in binary form must reproduce the above copyright
+//     notice, this list of conditions and the following disclaimer in the
+//     documentation and/or other materials provided with the distribution.
+//   * Neither the name of NVIDIA CORPORATION nor the names of its
+//     contributors may be used to endorse or promote products derived
+//     from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
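+
+// Usage sketch (illustrative; see backend_output_responder.h for the actual
+// interface): a backend typically constructs one BackendOutputResponder per
+// execution of a batch, calls ProcessTensor() / ProcessStateTensor() /
+// ProcessBatchOutput() once for each output tensor produced by the model,
+// and then calls Finalize(). For example:
+//
+//   BackendOutputResponder responder(/* requests, responses, ... */);
+//   responder.ProcessTensor(
+//       "OUTPUT0", datatype, batchn_shape, output_buffer, memory_type,
+//       memory_type_id);
+//   const bool need_sync = responder.Finalize();
+//   // When 'need_sync' is true, synchronize the CUDA stream before
+//   // sending the responses.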
+ +#include "triton/backend/backend_output_responder.h" + +#include "triton/backend/backend_common.h" +#include "triton/backend/backend_model.h" +#include "triton/backend/backend_model_instance.h" + +namespace triton { namespace backend { + +// +// BackendOutputResponder +// +BackendOutputResponder::~BackendOutputResponder() +{ + for (auto& pinned_memory : pinned_memories_) { + LOG_IF_ERROR( + TRITONBACKEND_MemoryManagerFree( + memory_manager_, reinterpret_cast(pinned_memory), + TRITONSERVER_MEMORY_CPU_PINNED, 0), + "failed to free pinned memory"); + } +} + +void +BackendOutputResponder::ProcessTensor( + const std::string& output_name, const TRITONSERVER_DataType datatype, + std::vector& batchn_shape, const char* buffer, + const TRITONSERVER_MemoryType memory_type, const int64_t memory_type_id) +{ + // A value of CPU_PINNED indicates that pinned memory buffer is not + // needed for this tensor. Any other value indicates that a pinned + // memory buffer is needed when the target memory type matches + // 'use_pinned_memory_type'. + TRITONSERVER_MemoryType use_pinned_memory_type = + TRITONSERVER_MEMORY_CPU_PINNED; + if (pinned_enabled_) { + use_pinned_memory_type = GetUsePinnedMemoryType(memory_type); + } + + const int64_t batchn_batch_size = batchn_shape[0]; + int64_t batch_size_offset = 0; + + size_t tensor_offset = 0; + + for (size_t idx = 0; idx < responses_->size(); idx++) { + auto& request = requests_[idx]; + auto& response = (*responses_)[idx]; + + // If then pending copies are from tensor buffer that is not + // contiguous with 'response's part of that buffer, then need to + // go ahead and perform the pending copies so that can start a + // new contiguous region if necessary. + if ((pending_pinned_byte_size_ > 0) && + (tensor_offset != + (pending_pinned_byte_size_ + pending_pinned_offset_))) { + need_sync_ |= FlushPendingPinned(buffer, memory_type, memory_type_id); + } + + // Override shape to be correct for this response. 
+ if (first_dim_batching_) { + TRITONBACKEND_Input* input; + TRITONBACKEND_RequestInputByIndex(request, 0, &input); + const int64_t* shape; + TRITONBACKEND_InputProperties( + input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr); + if ((batchn_batch_size != -1) && + ((batch_size_offset + shape[0]) > batchn_batch_size)) { + if (response != nullptr) { + RESPOND_AND_SET_NULL_IF_ERROR( + &response, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + std::string( + GetRequestId(request) + + "failed to split the output tensor '" + output_name + + "' in responses: expected batch size of atleast " + + std::to_string(batch_size_offset + shape[0]) + + " in model output, got " + + std::to_string(batchn_batch_size)) + .c_str())); + } + } + batchn_shape[0] = shape[0]; + batch_size_offset += shape[0]; + } + + const size_t tensor_byte_size = GetByteSize(datatype, batchn_shape); + + TRITONBACKEND_Output* response_output; + if (response != nullptr) { + uint32_t output_count; + RESPOND_AND_SET_NULL_IF_ERROR( + &response, TRITONBACKEND_RequestOutputCount(request, &output_count)); + if (response != nullptr) { + for (uint32_t output_idx = 0; output_idx < output_count; output_idx++) { + const char* name; + RESPOND_AND_SET_NULL_IF_ERROR( + &response, + TRITONBACKEND_RequestOutputName(request, output_idx, &name)); + if ((response != nullptr) && (output_name == name)) { + RESPOND_AND_SET_NULL_IF_ERROR( + &response, TRITONBACKEND_ResponseOutput( + response, &response_output, name, datatype, + batchn_shape.data(), batchn_shape.size())); + if (response != nullptr) { + need_sync_ |= SetFixedSizeBuffer( + &response, response_output, output_name, tensor_byte_size, + tensor_offset, buffer, memory_type, memory_type_id, + use_pinned_memory_type, false /* state */); + } + + break; + } + } + } + } + + tensor_offset += tensor_byte_size; + } + + // Done with the tensor, flush any pending pinned copies. + need_sync_ |= FlushPendingPinned(buffer, memory_type, memory_type_id); +#ifdef TRITON_ENABLE_GPU + if (need_sync_ && (event_ != nullptr)) { + cudaEventRecord(event_, stream_); + } +#endif // TRITON_ENABLE_GPU +} + +std::vector +BackendOutputResponder::ProcessStateTensor( + const std::string& output_state_name, const TRITONSERVER_DataType datatype, + std::vector& batchn_shape, const char* buffer, + const TRITONSERVER_MemoryType memory_type, const int64_t memory_type_id) +{ + // A value of CPU_PINNED indicates that pinned memory buffer is not + // needed for this tensor. Any other value indicates that a pinned + // memory buffer is needed when the target memory type matches + // 'use_pinned_memory_type'. + TRITONSERVER_MemoryType use_pinned_memory_type = + TRITONSERVER_MEMORY_CPU_PINNED; + if (pinned_enabled_) { + use_pinned_memory_type = GetUsePinnedMemoryType(memory_type); + } + + std::vector states; + + const int64_t batchn_batch_size = batchn_shape[0]; + int64_t batch_size_offset = 0; + + size_t tensor_offset = 0; + + for (size_t idx = 0; idx < responses_->size(); idx++) { + auto& request = requests_[idx]; + auto& response = (*responses_)[idx]; + + // If then pending copies are from tensor buffer that is not + // contiguous with 'response's part of that buffer, then need to + // go ahead and perform the pending copies so that can start a + // new contiguous region if necessary. 
+ if ((pending_pinned_byte_size_ > 0) && + (tensor_offset != + (pending_pinned_byte_size_ + pending_pinned_offset_))) { + need_sync_ |= FlushPendingPinned(buffer, memory_type, memory_type_id); + } + + // Override shape to be correct for this response. + if (first_dim_batching_) { + TRITONBACKEND_Input* input; + TRITONBACKEND_RequestInputByIndex(request, 0, &input); + const int64_t* shape; + TRITONBACKEND_InputProperties( + input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr); + if ((batchn_batch_size != -1) && + ((batch_size_offset + shape[0]) > batchn_batch_size)) { + if (response != nullptr) { + RESPOND_AND_SET_NULL_IF_ERROR( + &response, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + std::string( + GetRequestId(request) + + "failed to split the output state tensor '" + + output_state_name + + "' in responses: expected batch size of atleast " + + std::to_string(batch_size_offset + shape[0]) + + " in model output, got " + + std::to_string(batchn_batch_size)) + .c_str())); + } + } + batchn_shape[0] = shape[0]; + batch_size_offset += shape[0]; + } + + const size_t tensor_byte_size = GetByteSize(datatype, batchn_shape); + + TRITONBACKEND_State* output_state; + if (response != nullptr) { + RESPOND_AND_SET_NULL_IF_ERROR( + &response, TRITONBACKEND_StateNew( + &output_state, request, output_state_name.c_str(), + datatype, batchn_shape.data(), batchn_shape.size())); + if (response != nullptr) { + states.push_back(output_state); + need_sync_ |= SetFixedSizeBuffer( + &response, output_state, output_state_name, tensor_byte_size, + tensor_offset, buffer, memory_type, memory_type_id, + use_pinned_memory_type, true /* state */); + } + } + + tensor_offset += tensor_byte_size; + } + + // Done with the tensor, flush any pending pinned copies. + need_sync_ |= FlushPendingPinned(buffer, memory_type, memory_type_id); +#ifdef TRITON_ENABLE_GPU + if (need_sync_ && (event_ != nullptr)) { + cudaEventRecord(event_, stream_); + } +#endif // TRITON_ENABLE_GPU + + return states; +} + +bool +BackendOutputResponder::Finalize() +{ +#ifdef TRITON_ENABLE_GPU + if ((!deferred_pinned_.empty()) && need_sync_) { + if (event_ != nullptr) { + cudaEventSynchronize(event_); + } else { + cudaStreamSynchronize(stream_); + } + need_sync_ = false; + } +#endif // TRITON_ENABLE_GPU + + // After the above sync all the GPU->pinned copies are complete. Any + // deferred copies of pinned->CPU can now be done. 
+ for (auto& def : deferred_pinned_) { + auto pinned_memory_type = TRITONSERVER_MEMORY_CPU_PINNED; + int64_t pinned_memory_id = 0; + char* pinned_buffer = def.pinned_memory_; + + size_t offset = 0; + for (auto& pr : def.responses_) { + auto& response = pr.first; + auto& response_output = pr.second; + + bool cuda_used = false; + RESPOND_AND_SET_NULL_IF_ERROR( + response, + CopyBuffer( + response_output.name_, pinned_memory_type, pinned_memory_id, + response_output.memory_type_, response_output.memory_type_id_, + response_output.buffer_byte_size_, pinned_buffer + offset, + const_cast(response_output.buffer_), stream_, &cuda_used, + copy_on_stream_)); + need_sync_ |= cuda_used; + + offset += response_output.buffer_byte_size_; + } + } + +#ifdef TRITON_ENABLE_GPU + // Record the new event location if deferred copies occur + if ((!deferred_pinned_.empty()) && need_sync_ && (event_ != nullptr)) { + cudaEventRecord(event_, stream_); + } +#endif // TRITON_ENABLE_GPU + deferred_pinned_.clear(); + + return need_sync_; +} + + +bool +BackendOutputResponder::SetFixedSizeBuffer( + TRITONBACKEND_Response** response, void* response_output_or_state, + const std::string& output_name, const size_t tensor_byte_size, + const size_t tensor_offset, const char* tensor_buffer, + const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id, + const TRITONSERVER_MemoryType use_pinned_memory_type, bool state) +{ + void* buffer = nullptr; + bool cuda_copy = false; + + TRITONSERVER_MemoryType actual_memory_type = tensor_memory_type; + int64_t actual_memory_type_id = tensor_memory_type_id; + + if (state) { + TRITONBACKEND_State* response_state = + reinterpret_cast(response_output_or_state); + auto err = TRITONBACKEND_StateBuffer( + response_state, &buffer, tensor_byte_size, &actual_memory_type, + &actual_memory_type_id); + if (err != nullptr) { + RESPOND_AND_SET_NULL_IF_ERROR(response, err); + return cuda_copy; + } + } else { + TRITONBACKEND_Output* response_output = + reinterpret_cast(response_output_or_state); + auto err = TRITONBACKEND_OutputBuffer( + response_output, &buffer, tensor_byte_size, &actual_memory_type, + &actual_memory_type_id); + if (err != nullptr) { + RESPOND_AND_SET_NULL_IF_ERROR(response, err); + return cuda_copy; + } + } + + // If the response buffer matches the memory type that should use an + // intermediate pinned memory buffer for the transfer, then just + // record the response as pending and increase the size required for + // the intermediate pinned buffer. + if ((use_pinned_memory_type != TRITONSERVER_MEMORY_CPU_PINNED) && + (actual_memory_type == use_pinned_memory_type)) { + if (pending_pinned_byte_size_ == 0) { + pending_pinned_offset_ = tensor_offset; + } + + pending_pinned_byte_size_ += tensor_byte_size; + pending_pinned_outputs_.push_back(std::make_pair( + response, OutputData( + output_name, buffer, tensor_byte_size, actual_memory_type, + actual_memory_type_id))); + } else { + // Direct copy without intermediate pinned memory. 
+ bool cuda_used = false; + auto err = CopyBuffer( + output_name, tensor_memory_type, tensor_memory_type_id, + actual_memory_type, actual_memory_type_id, tensor_byte_size, + tensor_buffer + tensor_offset, buffer, stream_, &cuda_used, + copy_on_stream_); + cuda_copy |= cuda_used; + + if (err != nullptr) { + RESPOND_AND_SET_NULL_IF_ERROR(response, err); + return cuda_copy; + } + } + + return cuda_copy; +} + +bool +BackendOutputResponder::FlushPendingPinned( + const char* tensor_buffer, const TRITONSERVER_MemoryType tensor_memory_type, + const int64_t tensor_memory_type_id) +{ + bool cuda_copy = false; + + // Will be copying from CPU->pinned->GPU or GPU->pinned->CPU + + // Attempt to allocate a pinned buffer to use for staging the + // copy... if we fail to allocated the pinned buffer then we just + // directly go CPU->GPU or GPU->CPU. + char* pinned_memory = nullptr; + if (pending_pinned_byte_size_ > 0) { + TRITONSERVER_Error* err = TRITONBACKEND_MemoryManagerAllocate( + memory_manager_, reinterpret_cast(&pinned_memory), + TRITONSERVER_MEMORY_CPU_PINNED, 0 /* memory_type_id */, + pending_pinned_byte_size_); + if (err != nullptr) { + pinned_memory = nullptr; + TRITONSERVER_ErrorDelete(err); + } + } + + // If the pinned buffer wasn't actually allocated then just perform + // a direct copy. + if (pinned_memory == nullptr) { + size_t offset = 0; + for (auto& pr : pending_pinned_outputs_) { + auto& response = pr.first; + auto& response_output = pr.second; + + bool cuda_used = false; + RESPOND_AND_SET_NULL_IF_ERROR( + response, + CopyBuffer( + response_output.name_, tensor_memory_type, tensor_memory_type_id, + response_output.memory_type_, response_output.memory_type_id_, + response_output.buffer_byte_size_, + tensor_buffer + pending_pinned_offset_ + offset, + const_cast(response_output.buffer_), stream_, &cuda_used, + copy_on_stream_)); + cuda_copy |= cuda_used; + + offset += response_output.buffer_byte_size_; + } + } + // We have a pinned buffer so do a single copy of a block of tensor + // data to the pinned buffer. + else { // pinned_memory_type == TRITONSERVER_MEMORY_CPU_PINNED + bool cuda_used = false; + auto err = CopyBuffer( + "pinned buffer", tensor_memory_type, tensor_memory_type_id, + TRITONSERVER_MEMORY_CPU_PINNED, 0 /* memory_type_id */, + pending_pinned_byte_size_, tensor_buffer + pending_pinned_offset_, + pinned_memory, stream_, &cuda_used, copy_on_stream_); + cuda_copy |= cuda_used; + + // If something goes wrong with the copy all the pending + // responses fail... + if (err != nullptr) { + for (auto& pr : pending_pinned_outputs_) { + auto& response = pr.first; + if (*response != nullptr) { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + *response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), + "failed to send TensorFlow error response"); + *response = nullptr; + } + } + TRITONSERVER_ErrorDelete(err); + } + + // If the copy was not async (i.e. if tensor was in CPU so a + // CPU->CPU-PINNED copy was performed above), then the pinned + // buffer now holds the tensor contents and we can immediately + // issue the copies from the pinned buffer to the + // responses. + // + // Otherwise the GPU->CPU-PINNED async copies are in flight and we + // simply remember the pinned buffer and the corresponding + // response outputs so that we can do the pinned->CPU copies in + // finalize after we have waited for all async copies to complete. 
+ if (!cuda_used) { + size_t offset = 0; + for (auto& pr : pending_pinned_outputs_) { + auto& response = pr.first; + auto& response_output = pr.second; + + bool cuda_used = false; + RESPOND_AND_SET_NULL_IF_ERROR( + response, + CopyBuffer( + response_output.name_, TRITONSERVER_MEMORY_CPU_PINNED, + 0 /* memory_type_id */, response_output.memory_type_, + response_output.memory_type_id_, + response_output.buffer_byte_size_, pinned_memory + offset, + const_cast(response_output.buffer_), stream_, &cuda_used, + copy_on_stream_)); + cuda_copy |= cuda_used; + + offset += response_output.buffer_byte_size_; + } + } else { + deferred_pinned_.emplace_back( + pinned_memory, pending_pinned_byte_size_, + std::move(pending_pinned_outputs_)); + } + } + + // Pending pinned copies are handled... + pending_pinned_byte_size_ = 0; + pending_pinned_offset_ = 0; + pending_pinned_outputs_.clear(); + + // Need to hold on to the allocated pinned buffer as there are still + // copies in flight. Will delete it in finalize. + if (pinned_memory != nullptr) { + pinned_memories_.push_back(pinned_memory); + } + + return cuda_copy; +} + +void +BackendOutputResponder::ProcessBatchOutput( + const std::string& name, const BatchOutput& batch_output, + const char* buffer, const TRITONSERVER_MemoryType memory_type, + const int64_t memory_type_id) +{ + // A value of CPU_PINNED indicates that pinned memory buffer is not + // needed for this tensor. Any other value indicates that a pinned + // memory buffer is needed when the target memory type matches + // 'use_pinned_memory_type'. + TRITONSERVER_MemoryType use_pinned_memory_type = + TRITONSERVER_MEMORY_CPU_PINNED; + if (pinned_enabled_) { + use_pinned_memory_type = GetUsePinnedMemoryType(memory_type); + } + + // Batch output may be processed differently based on the kind + switch (batch_output.BatchOutputKind()) { + case BatchOutput::Kind::BATCH_SCATTER_WITH_INPUT_SHAPE: { + const auto& output_name = batch_output.TargetNames()[0]; + const auto& input_name = batch_output.SourceInputs()[0]; + const auto& datatype = batch_output.DataType(); + size_t tensor_offset = 0; + + for (size_t idx = 0; idx < responses_->size(); idx++) { + auto& request = requests_[idx]; + auto& response = (*responses_)[idx]; + + // If then pending copies are from tensor buffer that is not + // contiguous with 'response's part of that buffer, then need to + // go ahead and perform the pending copies so that can start a + // new contiguous region if necessary. 
+ if ((pending_pinned_byte_size_ > 0) && + (tensor_offset != + (pending_pinned_byte_size_ + pending_pinned_offset_))) { + need_sync_ |= FlushPendingPinned(buffer, memory_type, memory_type_id); + } + + // Override shape to be correct for this response, with a naive + // assumption that the dynamic dimension in output is mapped to the same + // dimension in the input + auto output_batchn_shape = batch_output.OutputShape(); + { + TRITONBACKEND_Input* input; + TRITONBACKEND_RequestInput(request, input_name.c_str(), &input); + const int64_t* shape; + TRITONBACKEND_InputProperties( + input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr); + for (size_t dim_idx = 0; dim_idx < output_batchn_shape.size(); + dim_idx++) { + if (output_batchn_shape[dim_idx] == -1) { + output_batchn_shape[dim_idx] = shape[dim_idx]; + } + } + } + + const size_t tensor_byte_size = + GetByteSize(datatype, output_batchn_shape); + + TRITONBACKEND_Output* response_output; + if (response != nullptr) { + uint32_t output_count; + RESPOND_AND_SET_NULL_IF_ERROR( + &response, + TRITONBACKEND_RequestOutputCount(request, &output_count)); + if (response != nullptr) { + for (uint32_t output_idx = 0; output_idx < output_count; + output_idx++) { + const char* name; + RESPOND_AND_SET_NULL_IF_ERROR( + &response, + TRITONBACKEND_RequestOutputName(request, output_idx, &name)); + if ((response != nullptr) && (output_name == name)) { + RESPOND_AND_SET_NULL_IF_ERROR( + &response, TRITONBACKEND_ResponseOutput( + response, &response_output, name, datatype, + output_batchn_shape.data(), + output_batchn_shape.size())); + if (response != nullptr) { + need_sync_ |= SetFixedSizeBuffer( + &response, response_output, output_name, tensor_byte_size, + tensor_offset, buffer, memory_type, memory_type_id, + use_pinned_memory_type, false /* state */); + } + + break; + } + } + } + } + + tensor_offset += tensor_byte_size; + } + break; + } + } + + // Done with the tensor, flush any pending pinned copies. + need_sync_ |= FlushPendingPinned(buffer, memory_type, memory_type_id); +#ifdef TRITON_ENABLE_GPU + if (need_sync_ && (event_ != nullptr)) { + cudaEventRecord(event_, stream_); + } +#endif // TRITON_ENABLE_GPU +} + +}} // namespace triton::backend diff --git a/3rdparty/backend-r22.12/src/kernel.cu b/3rdparty/backend-r22.12/src/kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..9f24dd0bdd0bf46b3573c2ad52b9bcb873dc0b6a --- /dev/null +++ b/3rdparty/backend-r22.12/src/kernel.cu @@ -0,0 +1,81 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "kernel.h"
+
+#include <cstdint>
+
+#define THREADBLOCK_SIZE 512
+__launch_bounds__(THREADBLOCK_SIZE) __global__ void TritonGatherKernel(
+    const int8_t** __restrict input_ptr_buffer,
+    const size_t* __restrict byte_size_buffer,
+    const size_t* __restrict byte_size_offset_buffer,
+    int8_t* __restrict output_buffer)
+{
+  int request_idx = blockIdx.x;
+  int lane_id = threadIdx.x;
+  const int8_t* request_input_buffer = input_ptr_buffer[request_idx];
+  int byte_size = byte_size_buffer[request_idx];
+  int byte_size_offset = byte_size_offset_buffer[request_idx];
+
+  int8_t* output_buffer_with_offset = output_buffer + byte_size_offset;
+  if (((byte_size % 4) == 0) && (((uint64_t)request_input_buffer % 4) == 0) &&
+      (((uint64_t)output_buffer_with_offset % 4) == 0)) {
+    int32_t* input_4 = (int32_t*)request_input_buffer;
+    int32_t* output_4 = (int32_t*)output_buffer_with_offset;
+    int element_count = byte_size / 4;
+    for (int elem_id = lane_id; elem_id < element_count;
+         elem_id += THREADBLOCK_SIZE) {
+      output_4[elem_id] = input_4[elem_id];
+    }
+  } else {
+    for (int elem_id = lane_id; elem_id < byte_size;
+         elem_id += THREADBLOCK_SIZE) {
+      output_buffer_with_offset[elem_id] =
+          __ldg(request_input_buffer + elem_id);
+    }
+  }
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+cudaError_t
+RunGatherKernel(
+    const int8_t** input_ptr_buffer, const size_t* byte_size_buffer,
+    const size_t* byte_size_offset_buffer, int8_t* output_buffer,
+    size_t request_count, cudaStream_t stream)
+{
+  // One thread block per request; each block copies that request's buffer
+  // into the contiguous output at its byte offset.
+  TritonGatherKernel<<<request_count, THREADBLOCK_SIZE, 0, stream>>>(
+      input_ptr_buffer, byte_size_buffer, byte_size_offset_buffer,
+      output_buffer);
+  return cudaGetLastError();
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/3rdparty/backend-r22.12/src/kernel.h b/3rdparty/backend-r22.12/src/kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..948d3051e472d53c3081cffbebe450a7e415f430
--- /dev/null
+++ b/3rdparty/backend-r22.12/src/kernel.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//   * Redistributions of source code must retain the above copyright
+//     notice, this list of conditions and the following disclaimer.
+//   * Redistributions in binary form must reproduce the above copyright
+//     notice, this list of conditions and the following disclaimer in the
+//     documentation and/or other materials provided with the distribution.
+//   * Neither the name of NVIDIA CORPORATION nor the names of its
+//     contributors may be used to endorse or promote products derived
+//     from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +cudaError_t RunGatherKernel( + const int8_t** input_ptr_buffer, const size_t* byte_size_buffer, + const size_t* byte_size_offset_buffer, int8_t* output_buffer, + size_t request_count, cudaStream_t stream); + +#ifdef __cplusplus +} +#endif diff --git a/3rdparty/common-r22.12/.clang-format b/3rdparty/common-r22.12/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..98c649734c29e0b1d134dae65be9bc08a14b4ba5 --- /dev/null +++ b/3rdparty/common-r22.12/.clang-format @@ -0,0 +1,37 @@ +--- +BasedOnStyle: Google + +IndentWidth: 2 +ContinuationIndentWidth: 4 +UseTab: Never +MaxEmptyLinesToKeep: 2 + +SortIncludes: true +CompactNamespaces: true +ReflowComments: true + +DerivePointerAlignment: false +PointerAlignment: Left + +AllowShortIfStatementsOnASingleLine: false +AllowShortBlocksOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline + +AlwaysBreakAfterReturnType: TopLevelDefinitions +AlignAfterOpenBracket: AlwaysBreak +BreakBeforeBraces: Custom +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: true + AfterNamespace: false + AfterStruct: false + AfterUnion: false + BeforeCatch: true + +BinPackArguments: true +BinPackParameters: true +ConstructorInitializerAllOnOneLineOrOnePerLine: false + +IndentCaseLabels: true \ No newline at end of file diff --git a/3rdparty/common-r22.12/.gitignore b/3rdparty/common-r22.12/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0e9f099a2eef4742716637e3cce3a45f7053b021 --- /dev/null +++ b/3rdparty/common-r22.12/.gitignore @@ -0,0 +1,3 @@ +/build +/.vscode +*.so diff --git a/3rdparty/common-r22.12/LICENSE b/3rdparty/common-r22.12/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a6bd4f2f5b3ecd75917ae39f06e6b5521c414491 --- /dev/null +++ b/3rdparty/common-r22.12/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of NVIDIA CORPORATION nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/3rdparty/common-r22.12/README.md b/3rdparty/common-r22.12/README.md new file mode 100644 index 0000000000000000000000000000000000000000..96e9e7bc68b80f9fec1e27b8da3c935fe9b6344b --- /dev/null +++ b/3rdparty/common-r22.12/README.md @@ -0,0 +1,51 @@ + + +[![License](https://img.shields.io/badge/License-BSD3-lightgrey.svg)](https://opensource.org/licenses/BSD-3-Clause) + +# Triton Inference Server Common + +Common source, scripts and utilities shared across all Triton +repositories. + +This repo is not typically built directly but is instead included in +the build of other repos. To build directly first install the required +dependencies. + +``` +$ apt-get install rapidjson-dev +``` + +Use cmake 3.17 or later to build and install in a local directory. + +``` +$ mkdir build +$ cd build +$ cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install .. +$ make install +``` diff --git a/3rdparty/common-r22.12/cmake/TritonCommonConfig.cmake.in b/3rdparty/common-r22.12/cmake/TritonCommonConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..56cb1461c00366eb321bc4696eac157ee4f43bf9 --- /dev/null +++ b/3rdparty/common-r22.12/cmake/TritonCommonConfig.cmake.in @@ -0,0 +1,51 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
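+#
+# Example consumer usage (illustrative only; the project/target name
+# 'my-backend' is made up):
+#
+#   find_package(TritonCommon REQUIRED)
+#   target_link_libraries(
+#     my-backend
+#     PRIVATE
+#       TritonCommon::triton-common-json
+#       TritonCommon::triton-common-async-work-queue
+#   )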
+ +@PACKAGE_INIT@ + +set_and_check(TRITONCOMMON_CMAKE_DIR "${CMAKE_CURRENT_LIST_DIR}") + +list(APPEND CMAKE_MODULE_PATH ${TRITONCOMMON_CMAKE_DIR}) + +include(CMakeFindDependencyMacro) +find_dependency(Threads) + +if(NOT TARGET TritonCommon::triton-common-json) + include("${TRITONCOMMON_CMAKE_DIR}/TritonCommonTargets.cmake") +endif() + +check_required_components(triton-common-json + triton-common-sync-queue + triton-common-async-work-queue + triton-common-thread-pool +) + +set(TRITONCOMMON_LIBRARIES + TritonCommon::triton-common-json + TritonCommon::triton-common-sync-queue + TritonCommon::triton-common-async-work-queue + TritonCommon::triton-common-thread-pool +) diff --git a/3rdparty/common-r22.12/include/triton/common/async_work_queue.h b/3rdparty/common-r22.12/include/triton/common/async_work_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..40afe7bb1af573429eaf0285deab77e5caa0a74a --- /dev/null +++ b/3rdparty/common-r22.12/include/triton/common/async_work_queue.h @@ -0,0 +1,59 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include "error.h" +#include "thread_pool.h" + +namespace triton { namespace common { +// Manager for asynchronous worker threads. Use to accelerate copies and +// other such operations by running them in parallel. +// Call Initialize to start the worker threads (once) and AddTask to tasks to +// the queue. + +class AsyncWorkQueue { + public: + // Start 'worker_count' number of worker threads. + static Error Initialize(size_t worker_count); + + // Get the number of worker threads. + static size_t WorkerCount(); + + // Add a 'task' to the queue. The function will take ownership of 'task'. + // Therefore std::move should be used when calling AddTask. 
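+  //
+  // Illustrative call sequence (not from the upstream sources):
+  //
+  //   triton::common::AsyncWorkQueue::Initialize(4 /* worker_count */);
+  //   std::function<void()> task = []() { /* copy or compute work */ };
+  //   triton::common::AsyncWorkQueue::AddTask(std::move(task));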
+  static Error AddTask(std::function<void()>&& task);
+
+ protected:
+  static void Reset();
+
+ private:
+  AsyncWorkQueue() = default;
+  ~AsyncWorkQueue();
+  static AsyncWorkQueue* GetSingleton();
+  std::unique_ptr<ThreadPool> thread_pool_;
+};
+
+}}  // namespace triton::common
diff --git a/3rdparty/common-r22.12/include/triton/common/error.h b/3rdparty/common-r22.12/include/triton/common/error.h
new file mode 100644
index 0000000000000000000000000000000000000000..cf8d30896ddcebd884f9d5b6f82ecbee31c95319
--- /dev/null
+++ b/3rdparty/common-r22.12/include/triton/common/error.h
@@ -0,0 +1,78 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//   * Redistributions of source code must retain the above copyright
+//     notice, this list of conditions and the following disclaimer.
+//   * Redistributions in binary form must reproduce the above copyright
+//     notice, this list of conditions and the following disclaimer in the
+//     documentation and/or other materials provided with the distribution.
+//   * Neither the name of NVIDIA CORPORATION nor the names of its
+//     contributors may be used to endorse or promote products derived
+//     from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#pragma once
+
+#include <string>
+
+namespace triton { namespace common {
+
+//
+// Error
+//
+// Error returned by utilities from common repo.
+//
+class Error {
+ public:
+  enum class Code {
+    SUCCESS,
+    UNKNOWN,
+    INTERNAL,
+    NOT_FOUND,
+    INVALID_ARG,
+    UNAVAILABLE,
+    UNSUPPORTED,
+    ALREADY_EXISTS
+  };
+
+  explicit Error(Code code = Code::SUCCESS) : code_(code) {}
+  explicit Error(Code code, const std::string& msg) : code_(code), msg_(msg) {}
+
+  // Convenience "success" value. Can be used as Error::Success to
+  // indicate no error.
+  static const Error Success;
+
+  // Return the code for this status.
+  Code ErrorCode() const { return code_; }
+
+  // Return the message for this status.
+  const std::string& Message() const { return msg_; }
+
+  // Return true if this status indicates "ok"/"success", false if
+  // status indicates some kind of failure.
+  bool IsOk() const { return code_ == Code::SUCCESS; }
+
+  // Return the status as a string.
+  std::string AsString() const;
+
+  // Return the constant string name for a code.
+ static const char* CodeString(const Code code); + + protected: + Code code_; + std::string msg_; +}; + +}} // namespace triton::common diff --git a/3rdparty/common-r22.12/include/triton/common/logging.h b/3rdparty/common-r22.12/include/triton/common/logging.h new file mode 100644 index 0000000000000000000000000000000000000000..a52c0c1918558c5e75e0d90ad8ea051d3a23bf93 --- /dev/null +++ b/3rdparty/common-r22.12/include/triton/common/logging.h @@ -0,0 +1,229 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace triton { namespace common { + +// A log message. +class LogMessage { + public: + // Log levels. + enum Level { kERROR = 0, kWARNING = 1, kINFO = 2 }; + + LogMessage(const char* file, int line, uint32_t level); + ~LogMessage(); + + std::stringstream& stream() { return stream_; } + + private: + static const std::vector level_name_; + std::stringstream stream_; +}; + +// Global logger for messages. Controls how log messages are reported. +class Logger { + public: + enum class Format { kDEFAULT, kISO8601 }; + + Logger(); + + // Is a log level enabled. + bool IsEnabled(LogMessage::Level level) const { return enables_[level]; } + + // Set enable for a log Level. + void SetEnabled(LogMessage::Level level, bool enable) + { + enables_[level] = enable; + } + + // Get the current verbose logging level. + uint32_t VerboseLevel() const { return vlevel_; } + + // Set the current verbose logging level. + void SetVerboseLevel(uint32_t vlevel) { vlevel_ = vlevel; } + + // Get the logging format. + Format LogFormat() { return format_; } + + // Get the logging format as a string. + std::string LogFormatString() + { + switch (format_) { + case Format::kISO8601: + return "ISO8601"; + case Format::kDEFAULT: + return "default"; + default: + return "Invalid format"; + } + } + + // Set the logging format. + void SetLogFormat(Format format) { format_ = format; } + + // Get the log output file name. 
+ const std::string& LogFile() { return filename_; } + + // Set the log output file. Returns an empty string upon + // success, else returns an error string. + const std::string SetLogFile(const std::string& filename) + { + const std::lock_guard lock(mutex_); + file_stream_.close(); + std::string revert_name(filename_); + filename_ = filename; + if (!filename_.empty()) { + file_stream_.open(filename_, std::ios::app); + if (file_stream_.fail()) { + std::stringstream error; + error << __FILE__ << " " << __LINE__ + << ": Failed to open log file: " << std::strerror(errno) + << std::endl; + filename_ = revert_name; + file_stream_.open(filename_, std::ios::app); + return error.str(); + } + } + // will return an empty string + return std::string(); + } + + // Log a message. + void Log(const std::string& msg); + + // Flush the log. + void Flush(); + + private: + std::vector enables_; + uint32_t vlevel_; + Format format_; + std::mutex mutex_; + std::string filename_; + std::ofstream file_stream_; +}; + +extern Logger gLogger_; + +#define LOG_ENABLE_INFO(E) \ + triton::common::gLogger_.SetEnabled( \ + triton::common::LogMessage::Level::kINFO, (E)) +#define LOG_ENABLE_WARNING(E) \ + triton::common::gLogger_.SetEnabled( \ + triton::common::LogMessage::Level::kWARNING, (E)) +#define LOG_ENABLE_ERROR(E) \ + triton::common::gLogger_.SetEnabled( \ + triton::common::LogMessage::Level::kERROR, (E)) +#define LOG_SET_VERBOSE(L) \ + triton::common::gLogger_.SetVerboseLevel( \ + static_cast(std::max(0, (L)))) +#define LOG_SET_OUT_FILE(FN) triton::common::gLogger_.SetLogFile((FN)) +#define LOG_SET_FORMAT(F) triton::common::gLogger_.SetLogFormat((F)) + +#define LOG_VERBOSE_LEVEL triton::common::gLogger_.VerboseLevel() +#define LOG_FORMAT triton::common::gLogger_.LogFormat() +#define LOG_FORMAT_STRING triton::common::gLogger_.LogFormatString() +#define LOG_FILE triton::common::gLogger_.LogFile() + +#ifdef TRITON_ENABLE_LOGGING + +#define LOG_INFO_IS_ON \ + triton::common::gLogger_.IsEnabled(triton::common::LogMessage::Level::kINFO) +#define LOG_WARNING_IS_ON \ + triton::common::gLogger_.IsEnabled( \ + triton::common::LogMessage::Level::kWARNING) +#define LOG_ERROR_IS_ON \ + triton::common::gLogger_.IsEnabled(triton::common::LogMessage::Level::kERROR) +#define LOG_VERBOSE_IS_ON(L) (triton::common::gLogger_.VerboseLevel() >= (L)) + +#else + +// If logging is disabled, define macro to be false to avoid further evaluation +#define LOG_INFO_IS_ON false +#define LOG_WARNING_IS_ON false +#define LOG_ERROR_IS_ON false +#define LOG_VERBOSE_IS_ON(L) false + +#endif // TRITON_ENABLE_LOGGING + +// Macros that use explicitly given filename and line number. +#define LOG_INFO_FL(FN, LN) \ + if (LOG_INFO_IS_ON) \ + triton::common::LogMessage( \ + (char*)(FN), LN, triton::common::LogMessage::Level::kINFO) \ + .stream() +#define LOG_WARNING_FL(FN, LN) \ + if (LOG_WARNING_IS_ON) \ + triton::common::LogMessage( \ + (char*)(FN), LN, triton::common::LogMessage::Level::kWARNING) \ + .stream() +#define LOG_ERROR_FL(FN, LN) \ + if (LOG_ERROR_IS_ON) \ + triton::common::LogMessage( \ + (char*)(FN), LN, triton::common::LogMessage::Level::kERROR) \ + .stream() +#define LOG_VERBOSE_FL(L, FN, LN) \ + if (LOG_VERBOSE_IS_ON(L)) \ + triton::common::LogMessage( \ + (char*)(FN), LN, triton::common::LogMessage::Level::kINFO) \ + .stream() + +// Macros that use current filename and line number. 
+#define LOG_INFO LOG_INFO_FL(__FILE__, __LINE__) +#define LOG_WARNING LOG_WARNING_FL(__FILE__, __LINE__) +#define LOG_ERROR LOG_ERROR_FL(__FILE__, __LINE__) +#define LOG_VERBOSE(L) LOG_VERBOSE_FL(L, __FILE__, __LINE__) + + +#define LOG_STATUS_ERROR(X, MSG) \ + do { \ + const Status& status__ = (X); \ + if (!status__.IsOk()) { \ + LOG_ERROR << (MSG) << ": " << status__.AsString(); \ + } \ + } while (false) + +#define LOG_TRITONSERVER_ERROR(X, MSG) \ + do { \ + TRITONSERVER_Error* err__ = (X); \ + if (err__ != nullptr) { \ + LOG_ERROR << (MSG) << ": " << TRITONSERVER_ErrorCodeString(err__) \ + << " - " << TRITONSERVER_ErrorMessage(err__); \ + TRITONSERVER_ErrorDelete(err__); \ + } \ + } while (false) + +#define LOG_FLUSH triton::common::gLogger_.Flush() + +}} // namespace triton::common diff --git a/3rdparty/common-r22.12/include/triton/common/model_config.h b/3rdparty/common-r22.12/include/triton/common/model_config.h new file mode 100644 index 0000000000000000000000000000000000000000..468f678cb6676083f2669c67534eb11e9787673e --- /dev/null +++ b/3rdparty/common-r22.12/include/triton/common/model_config.h @@ -0,0 +1,243 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include "model_config.pb.h" + +namespace triton { namespace common { + +/// The type for a repeated dims field (used for shape). +using DimsList = ::google::protobuf::RepeatedField<::google::protobuf::int64>; + +/// The type for the metric_tags map. +using MetricTagsMap = ::google::protobuf::Map; + +// Map from a host policy name to map of cmdline +// settings for the host policy. +using HostPolicyCmdlineConfig = std::map; +using HostPolicyCmdlineConfigMap = + std::unordered_map; + +// Map from backend name to list of setting=value pairs of cmdline +// settings for the backend. +using BackendCmdlineConfig = std::vector>; +using BackendCmdlineConfigMap = + std::unordered_map; + +/// The value for a dimension in a shape that indicates that that +/// dimension can take on any size. 
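+/// For example (illustrative), dims of [ WILDCARD_DIM, 300 ] describe a
+/// 2-D tensor whose first dimension is only known at inference time.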
+constexpr int WILDCARD_DIM = -1; + +constexpr int SCHEDULER_DEFAULT_NICE = 5; + +/// Enumeration for the different platform types. +enum Platform { + PLATFORM_UNKNOWN = 0, + PLATFORM_TENSORRT_PLAN = 1, + PLATFORM_TENSORFLOW_GRAPHDEF = 2, + PLATFORM_TENSORFLOW_SAVEDMODEL = 3, + PLATFORM_ENSEMBLE = 4, + PLATFORM_ONNXRUNTIME_ONNX = 5, + PLATFORM_PYTORCH_LIBTORCH = 6 +}; + +/// Get the number of elements in a shape. +/// \param dims The shape. +/// \return The number of elements, or -1 if the number of elements +/// cannot be determined because the shape contains one or more +/// wilcard dimensions. +int64_t GetElementCount(const DimsList& dims); + +/// Get the number of elements in a shape. +/// \param dims The shape. +/// \return The number of elements, or -1 if the number of elements +/// cannot be determined because the shape contains one or more +/// wilcard dimensions. +int64_t GetElementCount(const std::vector& dims); + +/// Get the number of elements in the shape of a model input. +/// \param mio The model input. +/// \return The number of elements, or -1 if the number of elements +/// cannot be determined because the shape contains one or more +/// wilcard dimensions. +int64_t GetElementCount(const inference::ModelInput& mio); + +/// Get the number of elements in the shape of a model output. +/// \param mio The model output. +/// \return The number of elements, or -1 if the number of elements +/// cannot be determined because the shape contains one or more +/// wilcard dimensions. +int64_t GetElementCount(const inference::ModelOutput& mio); + +/// Are values of a datatype fixed-size, or variable-sized. +/// \param dtype The data-type. +/// \return True if datatype values are fixed-sized, false if +/// variable-sized. +bool IsFixedSizeDataType(const inference::DataType dtype); + +/// Get the size of objects of a given datatype in bytes. +/// \param dtype The data-type. +/// \return The size, in bytes, of objects of the datatype, or 0 if +/// size cannot be determine (for example, values of type TYPE_STRING +/// have variable length and so size cannot be determine just from the +/// type). +size_t GetDataTypeByteSize(const inference::DataType dtype); + +/// Get the size, in bytes, of a tensor based on datatype and +/// shape. +/// \param dtype The data-type. +/// \param dims The shape. +/// \return The size, in bytes, of the corresponding tensor, or -1 if +/// unable to determine the size. +int64_t GetByteSize(const inference::DataType& dtype, const DimsList& dims); + +/// Get the size, in bytes, of a tensor based on datatype and +/// shape. +/// \param dtype The data-type. +/// \param dims The shape. +/// \return The size, in bytes, of the corresponding tensor, or -1 if +/// unable to determine the size. +int64_t GetByteSize( + const inference::DataType& dtype, const std::vector& dims); + +/// Get the size, in bytes, of a tensor based on batch-size, datatype +/// and shape. A tensor that has empty shape [] and non-zero +/// batch-size is sized as a tensor with shape [ batch-size ]. +/// \param batch_size The batch-size. May be 0 to indicate no +/// batching. +/// \param dtype The data-type. +/// \param dims The shape. +/// \return The size, in bytes, of the corresponding tensor, or -1 if +/// unable to determine the size. +int64_t GetByteSize( + const int batch_size, const inference::DataType& dtype, + const DimsList& dims); + +/// Get the size, in bytes, of a tensor based on batch-size, datatype +/// and shape. 
A tensor that has empty shape [] and non-zero +/// batch-size is sized as a tensor with shape [ batch-size ]. +/// \param batch_size The batch-size. May be 0 to indicate no +/// batching. +/// \param dtype The data-type. +/// \param dims The shape. +/// \return The size, in bytes, of the corresponding tensor, or -1 if +/// unable to determine the size. +int64_t GetByteSize( + const int batch_size, const inference::DataType& dtype, + const std::vector& dims); + +/// Get the size, in bytes, of a tensor based on ModelInput. +/// \param mio The ModelInput protobuf. +/// \return The size, in bytes, of the corresponding tensor, or -1 if +/// unable to determine the size. +int64_t GetByteSize(const inference::ModelInput& mio); + +/// Get the size, in bytes, of a tensor based on ModelOutput. +/// \param mio The ModelOutput protobuf. +/// \return The size, in bytes, of the corresponding tensor, or -1 if +/// unable to determine the size. +int64_t GetByteSize(const inference::ModelOutput& mio); + +/// Get the CPU thread nice level associate with a model +/// configuration's priority. +/// \param config The model configuration. +/// \return The nice level. +int GetCpuNiceLevel(const inference::ModelConfig& config); + +/// Compare two model configuration shapes for equality. Wildcard +/// dimensions (that is, dimensions with size WILDCARD_DIM) are +/// compared literally so that to be equal the two shapes must both +/// specify WILDCARD_DIM in the same dimensions. +/// \params dims0 The first shape. +/// \params dims1 The second shape. +/// \return True if the shapes are equal, false if not equal. +bool CompareDims(const DimsList& dims0, const DimsList& dims1); + +/// Compare two model configuration shapes for equality. Wildcard +/// dimensions (that is, dimensions with size WILDCARD_DIM) are +/// compared literally so that to be equal the two shapes must both +/// specify WILDCARD_DIM in the same dimensions. +/// \params dims0 The first shape. +/// \params dims1 The second shape. +/// \return True if the shapes are equal, false if not equal. +bool CompareDims( + const std::vector& dims0, const std::vector& dims1); + +/// Compare two model configuration shapes for equality. Wildcard +/// dimensions (that is, dimensions with size WILDCARD_DIM) are +/// allowed to match with any value. So, a dimension in one shape +/// specified as WILDCARD_DIM will always match the same dimension in +/// the other shape. +/// \params dims0 The first shape. +/// \params dims1 The second shape. +/// \return True if the shapes are equal, false if not equal. +bool CompareDimsWithWildcard(const DimsList& dims0, const DimsList& dims1); + +/// Compare two model configuration shapes for equality. Wildcard +/// dimensions (that is, dimensions with size WILDCARD_DIM) are +/// allowed to match with any value. So, a dimension in one shape +/// specified as WILDCARD_DIM will always match the same dimension in +/// the other shape. +/// \params dims0 The first shape. +/// \params dims1 The second shape. +/// \return True if the shapes are equal, false if not equal. +bool CompareDimsWithWildcard( + const DimsList& dims0, const std::vector& dims1); + +/// Convert a DimsList to string representation. +/// \param dims The DimsList to be converted. +/// \return String representation of the DimsList in pattern +/// "[d0,d1,...,dn]" +std::string DimsListToString(const DimsList& dims); + +/// Convert a vector representing a shape to string representation. +/// \param dims The vector of dimensions to be converted. 
+/// \return String representation of the vector in pattern +/// "[d0,d1,...,dn]" +std::string DimsListToString( + const std::vector& dims, const int start_idx = 0); + +/// Get the server protocol string representation of a datatype. +/// \param dtype The data type. +/// \return The string representation. +const char* DataTypeToProtocolString(const inference::DataType dtype); + +/// Get the datatype corresponding to a server protocol string +/// representation of a datatype. +/// \param dtype string representation. +/// \return The data type. +inference::DataType ProtocolStringToDataType(const std::string& dtype); + +/// Get the datatype corresponding to a server protocol string +/// representation of a datatype. +/// \param dtype Pointer to string. +/// \param len Length of the string. +/// \return The data type. +inference::DataType ProtocolStringToDataType(const char* dtype, size_t len); + +}} // namespace triton::common diff --git a/3rdparty/common-r22.12/include/triton/common/nvtx.h b/3rdparty/common-r22.12/include/triton/common/nvtx.h new file mode 100644 index 0000000000000000000000000000000000000000..450736cc5ffb11cbed448f51f6f607caceb325ce --- /dev/null +++ b/3rdparty/common-r22.12/include/triton/common/nvtx.h @@ -0,0 +1,59 @@ +// Copyright 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#ifdef TRITON_ENABLE_NVTX + +#include + +namespace triton { namespace common { + +// Updates a server stat with duration measured by a C++ scope. 
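+//
+// Illustrative usage sketch (not part of the upstream header; the function
+// and variable names below are hypothetical). The NVTX_RANGE and NVTX_MARKER
+// macros defined at the end of this file expand to nothing when
+// TRITON_ENABLE_NVTX is not defined, so call sites need no #ifdef guards:
+//
+//   void ExecuteBatch()
+//   {
+//     NVTX_RANGE(nvtx_exec, "ExecuteBatch");  // range pushed here, popped at scope exit
+//     NVTX_MARKER("batch formed");            // instantaneous marker
+//     // ... work being profiled ...
+//   }
+//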
+class NvtxRange { + public: + explicit NvtxRange(const char* label) { nvtxRangePushA(label); } + + explicit NvtxRange(const std::string& label) : NvtxRange(label.c_str()) {} + + ~NvtxRange() { nvtxRangePop(); } +}; + +}} // namespace triton::common + +#endif // TRITON_ENABLE_NVTX + +// +// Macros to access NVTX functionality +// +#ifdef TRITON_ENABLE_NVTX +#define NVTX_INITIALIZE nvtxInitialize(nullptr) +#define NVTX_RANGE(V, L) triton::common::NvtxRange V(L) +#define NVTX_MARKER(L) nvtxMarkA(L) +#else +#define NVTX_INITIALIZE +#define NVTX_RANGE(V, L) +#define NVTX_MARKER(L) +#endif // TRITON_ENABLE_NVTX diff --git a/3rdparty/common-r22.12/include/triton/common/sync_queue.h b/3rdparty/common-r22.12/include/triton/common/sync_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..2ab7a40cf525f27c6f093c4741defaec81013fcd --- /dev/null +++ b/3rdparty/common-r22.12/include/triton/common/sync_queue.h @@ -0,0 +1,83 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include + +namespace triton { namespace common { + +// +// C++11 doesn't have a sync queue so we implement a simple one. 
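+// Illustrative producer/consumer sketch (not part of the upstream header;
+// requires <thread> at the call site). Get() blocks until an item is
+// available:
+//
+//   triton::common::SyncQueue<int> queue;
+//   std::thread producer([&queue] { queue.Put(42); });
+//   int item = queue.Get();  // blocks until the producer calls Put()
+//   producer.join();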
+//
+template <typename Item>
+class SyncQueue {
+ public:
+  SyncQueue() {}
+
+  bool Empty()
+  {
+    std::lock_guard<std::mutex> lk(mu_);
+    return queue_.empty();
+  }
+
+  Item Get()
+  {
+    std::unique_lock<std::mutex> lk(mu_);
+    if (queue_.empty()) {
+      cv_.wait(lk, [this] { return !queue_.empty(); });
+    }
+    auto res = std::move(queue_.front());
+    queue_.pop_front();
+    return res;
+  }
+
+  void Put(const Item& value)
+  {
+    {
+      std::lock_guard<std::mutex> lk(mu_);
+      queue_.push_back(value);
+    }
+    cv_.notify_all();
+  }
+
+  void Put(Item&& value)
+  {
+    {
+      std::lock_guard<std::mutex> lk(mu_);
+      queue_.push_back(std::move(value));
+    }
+    cv_.notify_all();
+  }
+
+ private:
+  std::mutex mu_;
+  std::condition_variable cv_;
+  std::deque<Item> queue_;
+};
+
+}} // namespace triton::common
diff --git a/3rdparty/common-r22.12/include/triton/common/table_printer.h b/3rdparty/common-r22.12/include/triton/common/table_printer.h
new file mode 100644
index 0000000000000000000000000000000000000000..230e6c3043c1d74598805ab17891229d37bb54e0
--- /dev/null
+++ b/3rdparty/common-r22.12/include/triton/common/table_printer.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//  * Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//  * Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace triton { namespace common {
+
+//
+// An ASCII table printer.
+//
+class TablePrinter {
+ public:
+  // Insert a row at the end of the table
+  void InsertRow(const std::vector<std::string>& row);
+
+  // Print the table
+  std::string PrintTable();
+
+  // TablePrinter will take ownership of `headers`.
+  TablePrinter(const std::vector<std::string>& headers);
+
+ private:
+  // Update the `shares_` such that all the excess
+  // amount of space not used by a column is fairly allocated
+  // to the other columns
+  void FairShare();
+
+  // Append a row to `table`. This function handles the cases where a wrapping
+  // occurs.
+ void AddRow(std::stringstream& table, size_t row_index); + + // Add a row divider + void AddRowDivider(std::stringstream& table); + + // Max row width + std::vector max_widths_; + + // Max row height + std::vector max_heights_; + + // A vector of vectors of vectors containing data items for every column + // The record is stored in a vector of string, where each of the vector items + // contains a single line from the record. For example, ["Item 1", "Item 2", + // "Item 3\n Item 3 line 2"] will be stored as [["Item 1"], ["Item 2"], ["Item + // 3", "Item 3 line 2"]] + std::vector>> data_; + + // Fair share of every column + std::vector shares_; +}; + +}} // namespace triton::common diff --git a/3rdparty/common-r22.12/include/triton/common/thread_pool.h b/3rdparty/common-r22.12/include/triton/common/thread_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..787d4c1199e66d62d35f1af5e60497d1224035f0 --- /dev/null +++ b/3rdparty/common-r22.12/include/triton/common/thread_pool.h @@ -0,0 +1,61 @@ +// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include + +namespace triton { namespace common { + +// Generic fixed-size Thread Pool to execute tasks asynchronously + +class ThreadPool { + public: + explicit ThreadPool(std::size_t thread_count); + ~ThreadPool(); + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator=(const ThreadPool&) = delete; + + using Task = std::function; + // Assigns "task" to the task queue for a worker thread to execute when + // available. This will not track the return value of the task. + void Enqueue(Task&& task); + // Returns the number of threads in thread pool + size_t Size() { return workers_.size(); } + + private: + std::queue task_queue_; + std::mutex queue_mtx_; + std::condition_variable cv_; + std::vector workers_; + // If true, tells pool to stop accepting work and tells awake worker threads + // to exit when no tasks are left on the queue. 
+ bool stop_ = false; +}; + +}} // namespace triton::common diff --git a/3rdparty/common-r22.12/include/triton/common/triton_json.h b/3rdparty/common-r22.12/include/triton/common/triton_json.h new file mode 100644 index 0000000000000000000000000000000000000000..68d1cc39aee457726358990f736501701d43d2fc --- /dev/null +++ b/3rdparty/common-r22.12/include/triton/common/triton_json.h @@ -0,0 +1,1119 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#ifdef _WIN32 +// Remove GetObject definition from windows.h, which prevents calls to +// RapidJSON's GetObject. +// https://github.com/Tencent/rapidjson/issues/1448 +#undef GetObject +#include +#else +// Disable class-memaccess warning to facilitate compilation with gcc>7 +// https://github.com/Tencent/rapidjson/issues/1700 +#pragma GCC diagnostic push +#if defined(__GNUC__) && __GNUC__ >= 8 +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#endif +#include +#pragma GCC diagnostic pop +#endif // _WIN32 + +#include // CrtAllocator (default) for Writer instantiation +#include // UTF8 (default) for Writer instantiation +#include +#include +#include +#include +#include +#include +#include + +// This header can be used both within Triton server and externally +// (i.e. in source that interacts only via TRITONSERVER or +// TRITONBACKEND API). Status is handled differently in these cases so +// the following macros must be defined before including this +// header. As an example the defines are shown here as returned by the +// TRITONSERVER API. +// +// #define TRITONJSON_STATUSTYPE TRITONSERVER_Error* +// #define TRITONJSON_STATUSRETURN(M) +// return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, (M).c_str()) +// #define TRITONJSON_STATUSSUCCESS nullptr + +namespace triton { namespace common { + +// +// A JSON parser/writer. Currently based on rapidjson but the intent +// is to provide an abstraction for JSON functions that make it easy +// to substitute a different JSON parser. 
Specifically for rapidjson +// the class is also designed to provide safe access and error +// reporting to avoid the cases where rapidjson would just abort the +// entire application (!). +// +class TritonJson { + public: + class Value; + enum class ValueType { + OBJECT = rapidjson::kObjectType, + ARRAY = rapidjson::kArrayType, + }; + + // + // Buffer used when writing JSON representation. + // + class WriteBuffer { + public: + // Get buffer base address. + const char* Base() const { return buffer_.c_str(); } + + // Get a reference to the buffer itself. Useful to efficiently + // move the contents out of the buffer. + std::string& MutableContents() { return buffer_; } + + // Immutable contents. + const std::string& Contents() const { return buffer_; } + + // Interface required by rapidjson::Writer + typedef char Ch; + void Put(char c) { buffer_.push_back(c); } + void Clear() { buffer_.clear(); } + void Flush() { return; } + size_t Size() const { return buffer_.size(); } + + private: + std::string buffer_; + }; + + // + // Value representing the entire document or an element within a + // document. + // + class Value { + public: + // Empty value. Will become a top-level Document value if + // initialized by parsing or a non-top-level value if initialized + // any other way. + explicit Value() : value_(nullptr), allocator_(nullptr) {} + + // Construct a top-level JSON document. + explicit Value(const ValueType type) + : document_(static_cast(type)), value_(nullptr), + allocator_(&document_.GetAllocator()) + { + } + + // Construct a non-top-level JSON value in a 'document'. + explicit Value(TritonJson::Value& document, const ValueType type) + { + allocator_ = &document.document_.GetAllocator(); + value_ = new (allocator_->Malloc(sizeof(rapidjson::Value))) + rapidjson::Value(static_cast(type)); + } + + // Move constructor. + explicit Value(Value&& other) { *this = std::move(other); } + + // Move assignment operator. + Value& operator=(Value&& other) + { + document_ = std::move(other.document_); + value_ = other.value_; + allocator_ = other.allocator_; + other.value_ = nullptr; + other.allocator_ = nullptr; + return *this; + } + + // Parse JSON into document. Can only be called on top-level + // document value, otherwise error is returned. + TRITONJSON_STATUSTYPE Parse(const char* base, const size_t size) + { + if (value_ != nullptr) { + TRITONJSON_STATUSRETURN( + std::string("JSON parsing only available for top-level document")); + } + const unsigned int parseFlags = rapidjson::kParseNanAndInfFlag; + document_.Parse(base, size); + if (document_.HasParseError()) { + TRITONJSON_STATUSRETURN(std::string( + "failed to parse the request JSON buffer: " + + std::string(GetParseError_En(document_.GetParseError())) + " at " + + std::to_string(document_.GetErrorOffset()))); + } + allocator_ = &document_.GetAllocator(); + return TRITONJSON_STATUSSUCCESS; + } + + // \see Parse(const char* base, const size_t size) + TRITONJSON_STATUSTYPE Parse(const std::string& json) + { + return Parse(json.data(), json.size()); + } + + // Write JSON representation into a 'buffer' in a compact + // format. Can only be called for a top-level document value, + // otherwise error is returned. 
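+    //
+    // Illustrative sketch (not part of the upstream header; status returns
+    // are ignored for brevity):
+    //
+    //   triton::common::TritonJson::Value doc(
+    //       triton::common::TritonJson::ValueType::OBJECT);
+    //   doc.AddString("name", "simple");
+    //   triton::common::TritonJson::WriteBuffer buffer;
+    //   doc.Write(&buffer);  // buffer.Contents() == "{\"name\":\"simple\"}"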
+ TRITONJSON_STATUSTYPE Write(WriteBuffer* buffer) const + { + if (value_ != nullptr) { + TRITONJSON_STATUSRETURN( + std::string("JSON writing only available for top-level document")); + } + const unsigned int writeFlags = rapidjson::kWriteNanAndInfFlag; + // Provide default template arguments to pass writeFlags + rapidjson::Writer< + WriteBuffer, rapidjson::UTF8<>, rapidjson::UTF8<>, + rapidjson::CrtAllocator, writeFlags> + writer(*buffer); + if (!document_.Accept(writer)) { + TRITONJSON_STATUSRETURN( + std::string("Failed to accept document, invalid JSON.")); + } + return TRITONJSON_STATUSSUCCESS; + } + + // Write JSON representation into a 'buffer' in an easy-to-read + // format. Can only be called for a top-level document value, + // otherwise error is returned. + TRITONJSON_STATUSTYPE PrettyWrite(WriteBuffer* buffer) const + { + if (value_ != nullptr) { + TRITONJSON_STATUSRETURN( + std::string("JSON writing only available for top-level document")); + } + + // Can't pass writeFlags with latest release v1.1.0 of rapidjson-dev. + // We would need to build rapidjson from source to capture latest fixes. + // See this issue: + // https://github.com/Tencent/rapidjson/issues/905#issuecomment-370981353 + // PrettyWrite is only used for displaying model configs currently, so + // this should not be an issue. + rapidjson::PrettyWriter writer(*buffer); + if (!document_.Accept(writer)) { + TRITONJSON_STATUSRETURN( + std::string("Failed to accept document, invalid JSON.")); + } + return TRITONJSON_STATUSSUCCESS; + } + + // Swap a value with another. + TRITONJSON_STATUSTYPE Swap(TritonJson::Value& other) + { + rapidjson::Value& value = AsMutableValue(); + value.Swap(other.AsMutableValue()); + return TRITONJSON_STATUSSUCCESS; + } + + // FIXME Should have Set* for all types. + + // Set/overwrite a signed integer in a value. This changes the + // type of the value to signed int. + TRITONJSON_STATUSTYPE SetInt(const int64_t value) + { + rapidjson::Value& v = AsMutableValue(); + v.SetInt64(value); + return TRITONJSON_STATUSSUCCESS; + } + + // Set/overwrite a string in a value. This changes the + // type of the value to string + TRITONJSON_STATUSTYPE SetString(const std::string& value) + { + rapidjson::Value& v = AsMutableValue(); + v.SetString(value.c_str(), value.length(), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Set/overwrite a string member with provided name and value in this object + TRITONJSON_STATUSTYPE SetStringObject( + const char* name, const std::string& value) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to add/replace JSON member '") + name + + "' to non-object"); + } + auto itr = object.FindMember(name); + if (itr == object.MemberEnd()) { + AddString(name, value); + } else { + object.RemoveMember(itr); + object.AddMember( + rapidjson::Value(rapidjson::StringRef(name)).Move(), + rapidjson::Value(value.c_str(), value.size(), *allocator_), + *allocator_); + } + + return TRITONJSON_STATUSSUCCESS; + } + + // Add an array or object as a new member to this value. 'value' + // is moved into this value and so on return 'value' should not be + // used. It is assumed that 'name' can be used by reference, it is + // the caller's responsibility to make sure the lifetime of 'name' + // extends at least as long as the object. 
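+    //
+    // Illustrative sketch (not part of the upstream header; 'doc' is a
+    // hypothetical top-level OBJECT document):
+    //
+    //   TritonJson::Value inputs(doc, TritonJson::ValueType::ARRAY);
+    //   doc.Add("inputs", std::move(inputs));  // 'inputs' must not be used afterwards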
+ TRITONJSON_STATUSTYPE Add(const char* name, TritonJson::Value&& value) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to add JSON member '") + name + + "' to non-object"); + } + if (value.value_ == nullptr) { + rapidjson::Value v2; + v2.CopyFrom(value.document_, *allocator_); + object.AddMember( + rapidjson::Value(rapidjson::StringRef(name)).Move(), v2.Move(), + *allocator_); + } else { + object.AddMember( + rapidjson::Value(rapidjson::StringRef(name)).Move(), + value.value_->Move(), *allocator_); + } + value.Release(); + return TRITONJSON_STATUSSUCCESS; + } + + // Add a copy of a string as a new member to this value. It is + // assumed that 'name' can be used by reference, it is the + // caller's responsibility to make sure the lifetime of 'name' + // extends at least as long as the object. + TRITONJSON_STATUSTYPE AddString(const char* name, const std::string& value) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to add JSON member '") + name + + "' to non-object"); + } + object.AddMember( + rapidjson::Value(rapidjson::StringRef(name)).Move(), + rapidjson::Value(value.c_str(), value.size(), *allocator_).Move(), + *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Add a copy of a explicit-length string as a new member to this + // value. It is assumed that 'name' can be used by reference, it + // is the caller's responsibility to make sure the lifetime of + // 'name' extends at least as long as the object. + TRITONJSON_STATUSTYPE AddString( + const char* name, const char* value, const size_t len) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to add JSON member '") + name + + "' to non-object"); + } + object.AddMember( + rapidjson::Value(rapidjson::StringRef(name)).Move(), + rapidjson::Value(value, len, *allocator_).Move(), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Add a reference to a string as a new member to this value. It + // is assumed that 'name' and 'value' can be used by reference, it + // is the caller's responsibility to make sure the lifetime of + // 'name' and 'value' extend at least as long as the object. + TRITONJSON_STATUSTYPE AddStringRef(const char* name, const char* value) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to add JSON member '") + name + + "' to non-object"); + } + object.AddMember( + rapidjson::Value(rapidjson::StringRef(name)).Move(), + rapidjson::StringRef(value), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Add a reference to a expicit-length string as a new member to + // this value. It is assumed that 'name' and 'value' can be used + // by reference, it is the caller's responsibility to make sure + // the lifetime of 'name' and 'value' extend at least as long as + // the object. + TRITONJSON_STATUSTYPE AddStringRef( + const char* name, const char* value, const size_t len) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to add JSON member '") + name + + "' to non-object"); + } + object.AddMember( + rapidjson::Value(rapidjson::StringRef(name)).Move(), + rapidjson::StringRef(value, len), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Add a boolean new member to this value. 
It is assumed that + // 'name' can be used by reference, it is the caller's + // responsibility to make sure the lifetime of 'name' extends at + // least as long as the object. + TRITONJSON_STATUSTYPE AddBool(const char* name, const bool value) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to add JSON member '") + name + + "' to non-object"); + } + object.AddMember( + rapidjson::Value(rapidjson::StringRef(name)).Move(), + rapidjson::Value(value).Move(), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Add a signed integer as a new member to this value. It is + // assumed that 'name' can be used by reference, it is the + // caller's responsibility to make sure the lifetime of 'name' + // extends at least as long as the object. + TRITONJSON_STATUSTYPE AddInt(const char* name, const int64_t value) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to add JSON member '") + name + + "' to non-object"); + } + object.AddMember( + rapidjson::Value(rapidjson::StringRef(name)).Move(), + rapidjson::Value(value).Move(), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Add an unsigned integer as a new member to this value. It is + // assumed that 'name' can be used by reference, it is the + // caller's responsibility to make sure the lifetime of 'name' + // extends at least as long as the object. + TRITONJSON_STATUSTYPE AddUInt(const char* name, const uint64_t value) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to add JSON member '") + name + + "' to non-object"); + } + object.AddMember( + rapidjson::Value(rapidjson::StringRef(name)).Move(), + rapidjson::Value(value).Move(), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Add a double as a new member to this value. It is assumed that + // 'name' can be used by reference, it is the caller's + // responsibility to make sure the lifetime of 'name' extends at + // least as long as the object. + TRITONJSON_STATUSTYPE AddDouble(const char* name, const double value) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to add JSON member '") + name + + "' to non-object"); + } + object.AddMember( + rapidjson::Value(rapidjson::StringRef(name)).Move(), + rapidjson::Value(value).Move(), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Append an array or object to this value, which must be an + // array. 'value' is moved into this value and so on return + // 'value' should not be used. + TRITONJSON_STATUSTYPE Append(TritonJson::Value&& value) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to append JSON member to non-array")); + } + if (value.value_ == nullptr) { + rapidjson::Value v2; + v2.CopyFrom(value.document_, *allocator_); + array.PushBack(v2.Move(), *allocator_); + } else { + array.PushBack(value.value_->Move(), *allocator_); + } + + value.Release(); + return TRITONJSON_STATUSSUCCESS; + } + + // Append a copy of a string to this value, which must be an + // array. 
+ TRITONJSON_STATUSTYPE AppendString(const std::string& value) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to append JSON member to non-array")); + } + array.PushBack( + rapidjson::Value(value.c_str(), value.size(), *allocator_).Move(), + *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Append a copy of an explicit-length string to this value, which + // must be an array. + TRITONJSON_STATUSTYPE AppendString(const char* value, const size_t len) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to append JSON member to non-array")); + } + array.PushBack( + rapidjson::Value(value, len, *allocator_).Move(), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Append a reference to a string to this value, which must be an + // array. It is assumed that 'value' can be used by reference, it + // is the caller's responsibility to make sure the lifetime of + // 'value' extends at least as long as the object. + TRITONJSON_STATUSTYPE AppendStringRef(const char* value) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to append JSON member to non-array")); + } + array.PushBack(rapidjson::StringRef(value), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Append a reference to a expicit-length string to this value, + // which must be an array. It is assumed that 'value' can be used + // by reference, it is the caller's responsibility to make sure + // the lifetime of 'value' extends at least as long as the object. + TRITONJSON_STATUSTYPE AppendStringRef(const char* value, const size_t len) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to append JSON member to non-array")); + } + array.PushBack(rapidjson::StringRef(value, len), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Append a boolean to this value, which must be an array. + TRITONJSON_STATUSTYPE AppendBool(const bool value) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to append JSON member to non-array")); + } + + array.PushBack(rapidjson::Value(value).Move(), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Append a signed integer to this value, which must be an array. + TRITONJSON_STATUSTYPE AppendInt(const int64_t value) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to append JSON member to non-array")); + } + + array.PushBack(rapidjson::Value(value).Move(), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Append an unsigned integer to this value, which must be an + // array. + TRITONJSON_STATUSTYPE AppendUInt(const uint64_t value) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to append JSON member to non-array")); + } + + array.PushBack(rapidjson::Value(value).Move(), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Append a double to this value, which must be an array. 
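+    //
+    // Illustrative sketch of the Append* family (not part of the upstream
+    // header; 'doc' is a hypothetical top-level OBJECT document):
+    //
+    //   TritonJson::Value shape(doc, TritonJson::ValueType::ARRAY);
+    //   shape.AppendInt(1);
+    //   shape.AppendInt(224);
+    //   shape.AppendDouble(0.5);
+    //   doc.Add("shape", std::move(shape));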
+ TRITONJSON_STATUSTYPE AppendDouble(const double value) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to append JSON member to non-array")); + } + + array.PushBack(rapidjson::Value(value).Move(), *allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Remove member from this object + TRITONJSON_STATUSTYPE Remove(const char* name) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to remove JSON member '") + name + + "' to non-object"); + } + auto itr = object.FindMember(name); + if (itr != object.MemberEnd()) { + object.RemoveMember(itr); + } // else report success + + return TRITONJSON_STATUSSUCCESS; + } + + // Check if this value is of the specified type. Return appropriate + // error if not. + TRITONJSON_STATUSTYPE AssertType(TritonJson::ValueType type) const + { + if (static_cast(type) != AsValue().GetType()) { + TRITONJSON_STATUSRETURN(std::string("unexpected type")); + } + return TRITONJSON_STATUSSUCCESS; + } + + // Get the size of an array. If called on non-array returns zero. + size_t ArraySize() const + { + const rapidjson::Value& array = AsValue(); + if (!array.IsArray()) { + return 0; + } + return array.GetArray().Size(); + } + + // Return the specified index contained in this array. + TRITONJSON_STATUSTYPE At( + const size_t idx, TritonJson::Value* value = nullptr) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray() || (idx >= array.GetArray().Size())) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing array index '") + + std::to_string(idx) + "'"); + } + *value = TritonJson::Value(array[idx], allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Get the names of all members in an object. Error if value is + // not an object. + TRITONJSON_STATUSTYPE Members(std::vector* names) const + { + const rapidjson::Value& object = AsValue(); + if (!object.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to get members for non-object")); + } + for (const auto& m : object.GetObject()) { + names->push_back(m.name.GetString()); + } + return TRITONJSON_STATUSSUCCESS; + } + + // Return true if this value is an object and the named member is + // contained in this object. + bool Find(const char* name) const + { + const rapidjson::Value& object = AsValue(); + return object.IsObject() && object.HasMember(name); + } + + // Return true if this value is an object and the named member is + // contained in this object. Return the member in 'value'. + bool Find(const char* name, TritonJson::Value* value) + { + rapidjson::Value& object = AsMutableValue(); + if (object.IsObject() && object.HasMember(name)) { + if (value != nullptr) { + *value = TritonJson::Value(object[name], allocator_); + } + return true; + } + + return false; + } + + // Whether the object is null value. Note that false will also be retuned + // if the object is not a JSON value. + bool IsNull() const { return ((value_ != nullptr) && value_->IsNull()); } + + // Return true if the object is an object and it has no members; + // false otherwise. + bool IsEmpty() const + { + const rapidjson::Value& object = AsValue(); + if (object.IsObject() && object.MemberCount() == 0) { + return true; + } + return false; + } + + // Get value as a string. The string may contain null or other + // special characters and so 'len' must be used to determine length. + // Error if value is not a string. 
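+    //
+    // Illustrative sketch (not part of the upstream header; 'value' is a
+    // hypothetical string-valued element):
+    //
+    //   const char* base = nullptr;
+    //   size_t byte_size = 0;
+    //   value.AsString(&base, &byte_size);  // use 'byte_size'; the string may contain embedded nulls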
+ TRITONJSON_STATUSTYPE AsString(const char** value, size_t* len) const + { + if ((value_ == nullptr) || !value_->IsString()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-string as string")); + } + *value = value_->GetString(); + *len = value_->GetStringLength(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get value as a string. The string may contain null or other + // special characters. Error if value is not a string. + TRITONJSON_STATUSTYPE AsString(std::string* str) const + { + if ((value_ == nullptr) || !value_->IsString()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-string as string")); + } + str->assign(value_->GetString(), value_->GetStringLength()); + return TRITONJSON_STATUSSUCCESS; + } + + // Get value as a boolean. Error if value is not a boolean. + TRITONJSON_STATUSTYPE AsBool(bool* value) const + { + if ((value_ == nullptr) || !value_->IsBool()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-boolean as boolean")); + } + *value = value_->GetBool(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get value as a signed integer. Error if value is not a signed + // integer. + TRITONJSON_STATUSTYPE AsInt(int64_t* value) const + { + if ((value_ == nullptr) || !value_->IsInt64()) { + TRITONJSON_STATUSRETURN(std::string( + "attempt to access JSON non-signed-integer as signed-integer")); + } + *value = value_->GetInt64(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get value as an unsigned integer. Error if value is not an + // unsigned integer. + TRITONJSON_STATUSTYPE AsUInt(uint64_t* value) const + { + if ((value_ == nullptr) || !value_->IsUint64()) { + TRITONJSON_STATUSRETURN(std::string( + "attempt to access JSON non-unsigned-integer as unsigned-integer")); + } + *value = value_->GetUint64(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get value as a double. Error if value is not a double. + TRITONJSON_STATUSTYPE AsDouble(double* value) const + { + if ((value_ == nullptr) || !value_->IsNumber()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-number as double")); + } + *value = value_->GetDouble(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get named array member contained in this object. + TRITONJSON_STATUSTYPE MemberAsArray( + const char* name, TritonJson::Value* value) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject() || !object.HasMember(name)) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing object member '") + + name + "'"); + } + auto& v = object[name]; + if (!v.IsArray()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-array as array")); + } + *value = TritonJson::Value(v, allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Get named object member contained in this object. + TRITONJSON_STATUSTYPE MemberAsObject( + const char* name, TritonJson::Value* value) + { + rapidjson::Value& object = AsMutableValue(); + if (!object.IsObject() || !object.HasMember(name)) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing object member '") + + name + "'"); + } + auto& v = object[name]; + if (!v.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-object as object")); + } + *value = TritonJson::Value(v, allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Get object member as a string. The string may contain null or other + // special characters and so 'len' must be used to determine length. 
+ // Error if this is not an object or if the member is not a string. + TRITONJSON_STATUSTYPE MemberAsString( + const char* name, const char** value, size_t* len) const + { + const rapidjson::Value& object = AsValue(); + if (!object.IsObject() || !object.HasMember(name)) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing object member '") + + name + "'"); + } + const auto& v = object[name]; + if (!v.IsString()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-string as string")); + } + *value = v.GetString(); + *len = v.GetStringLength(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get object member as a string. The string may contain null or + // other special characters. Error if this is not an object or if + // the member is not a string. + TRITONJSON_STATUSTYPE MemberAsString( + const char* name, std::string* str) const + { + const rapidjson::Value& object = AsValue(); + if (!object.IsObject() || !object.HasMember(name)) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing object member '") + + name + "'"); + } + const auto& v = object[name]; + if (!v.IsString()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-string as string")); + } + str->assign(v.GetString(), v.GetStringLength()); + return TRITONJSON_STATUSSUCCESS; + } + + // Get object member as a boolean. Error if this is not an object + // or if the member is not a boolean. + TRITONJSON_STATUSTYPE MemberAsBool(const char* name, bool* value) const + { + const rapidjson::Value& object = AsValue(); + if (!object.IsObject() || !object.HasMember(name)) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing object member '") + + name + "'"); + } + const auto& v = object[name]; + if (!v.IsBool()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-boolean as boolean")); + } + *value = v.GetBool(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get object member as a signed integer. Error if this is not an object + // or if the member is not a signed integer. + TRITONJSON_STATUSTYPE MemberAsInt(const char* name, int64_t* value) const + { + const rapidjson::Value& object = AsValue(); + if (!object.IsObject() || !object.HasMember(name)) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing object member '") + + name + "'"); + } + const auto& v = object[name]; + if (!v.IsInt64()) { + TRITONJSON_STATUSRETURN(std::string( + "attempt to access JSON non-signed-integer as signed-integer")); + } + *value = v.GetInt64(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get object member as an unsigned integer. Error if this is not an object + // or if the member is not an unsigned integer. + TRITONJSON_STATUSTYPE MemberAsUInt(const char* name, uint64_t* value) const + { + const rapidjson::Value& object = AsValue(); + if (!object.IsObject() || !object.HasMember(name)) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing object member '") + + name + "'"); + } + const auto& v = object[name]; + if (!v.IsUint64()) { + TRITONJSON_STATUSRETURN(std::string( + "attempt to access JSON non-unsigned-integer as unsigned-integer")); + } + *value = v.GetUint64(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get object member as a double. Error if this is not an object + // or if the member is not a double. 
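+    //
+    // Illustrative sketch of reading members from a parsed document (not part
+    // of the upstream header; status returns are ignored for brevity):
+    //
+    //   triton::common::TritonJson::Value config;
+    //   config.Parse(R"({"name": "simple", "max_batch_size": 8})");
+    //   std::string name;
+    //   config.MemberAsString("name", &name);
+    //   int64_t max_batch_size = 0;
+    //   config.MemberAsInt("max_batch_size", &max_batch_size);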
+ TRITONJSON_STATUSTYPE MemberAsDouble(const char* name, double* value) const + { + const rapidjson::Value& object = AsValue(); + if (!object.IsObject() || !object.HasMember(name)) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing object member '") + + name + "'"); + } + const auto& v = object[name]; + if (!v.IsNumber()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-number as double")); + } + *value = v.GetDouble(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get array element at a given index within this array. + TRITONJSON_STATUSTYPE IndexAsArray( + const size_t idx, TritonJson::Value* value) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray() || (idx >= array.GetArray().Size())) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing array index '") + + std::to_string(idx) + "'"); + } + auto& v = array[idx]; + if (!v.IsArray()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-array as array")); + } + *value = TritonJson::Value(v, allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Get object element at a given index within this array. + TRITONJSON_STATUSTYPE IndexAsObject( + const size_t idx, TritonJson::Value* value) + { + rapidjson::Value& array = AsMutableValue(); + if (!array.IsArray() || (idx >= array.GetArray().Size())) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing array index '") + + std::to_string(idx) + "'"); + } + auto& v = array[idx]; + if (!v.IsObject()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-object as object")); + } + *value = TritonJson::Value(v, allocator_); + return TRITONJSON_STATUSSUCCESS; + } + + // Get array index as a string. The string may contain null or + // other special characters and so 'len' must be used to determine + // length. Error if this is not an array or if the index element + // is not a string. + TRITONJSON_STATUSTYPE IndexAsString( + const size_t idx, const char** value, size_t* len) const + { + const rapidjson::Value& array = AsValue(); + if (!array.IsArray() || (idx >= array.GetArray().Size())) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing array index '") + + std::to_string(idx) + "'"); + } + const auto& v = array[idx]; + if (!v.IsString()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-string as string")); + } + *value = v.GetString(); + *len = v.GetStringLength(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get array index as a string. The string may contain null or + // other special characters. Error if this is not an array or if + // the index element is not a string. + TRITONJSON_STATUSTYPE IndexAsString( + const size_t idx, std::string* str) const + { + const rapidjson::Value& array = AsValue(); + if (!array.IsArray() || (idx >= array.GetArray().Size())) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing array index '") + + std::to_string(idx) + "'"); + } + const auto& v = array[idx]; + if (!v.IsString()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-string as string")); + } + str->assign(v.GetString(), v.GetStringLength()); + return TRITONJSON_STATUSSUCCESS; + } + + // Get array index as a boolean. Error if this is not an array or + // if the index element is not a boolean. 
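+    //
+    // Illustrative sketch of iterating an array with ArraySize() and the
+    // IndexAs* accessors (not part of the upstream header; 'config' is the
+    // hypothetical parsed document from the sketch above):
+    //
+    //   triton::common::TritonJson::Value inputs;
+    //   config.MemberAsArray("inputs", &inputs);
+    //   for (size_t i = 0; i < inputs.ArraySize(); ++i) {
+    //     triton::common::TritonJson::Value input;
+    //     inputs.IndexAsObject(i, &input);
+    //     // ... read 'input' members here ...
+    //   }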
+ TRITONJSON_STATUSTYPE IndexAsBool(const size_t idx, bool* value) const + { + const rapidjson::Value& array = AsValue(); + if (!array.IsArray() || (idx >= array.GetArray().Size())) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing array index '") + + std::to_string(idx) + "'"); + } + const auto& v = array[idx]; + if (!v.IsBool()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-boolean as boolean")); + } + *value = v.GetBool(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get array index as a signed integer. Error if this is not an array or + // if the index element is not a signed integer. + TRITONJSON_STATUSTYPE IndexAsInt(const size_t idx, int64_t* value) const + { + const rapidjson::Value& array = AsValue(); + if (!array.IsArray() || (idx >= array.GetArray().Size())) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing array index '") + + std::to_string(idx) + "'"); + } + const auto& v = array[idx]; + if (!v.IsInt64()) { + TRITONJSON_STATUSRETURN(std::string( + "attempt to access JSON non-signed-integer as signed-integer")); + } + *value = v.GetInt64(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get array index as an unsigned integer. Error if this is not an array or + // if the index element is not an unsigned integer. + TRITONJSON_STATUSTYPE IndexAsUInt(const size_t idx, uint64_t* value) const + { + const rapidjson::Value& array = AsValue(); + if (!array.IsArray() || (idx >= array.GetArray().Size())) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing array index '") + + std::to_string(idx) + "'"); + } + const auto& v = array[idx]; + if (!v.IsUint64()) { + TRITONJSON_STATUSRETURN(std::string( + "attempt to access JSON non-unsigned-integer as unsigned-integer")); + } + *value = v.GetUint64(); + return TRITONJSON_STATUSSUCCESS; + } + + // Get array index as a double. Error if this is not an array or + // if the index element is not a double. + TRITONJSON_STATUSTYPE IndexAsDouble(const size_t idx, double* value) const + { + const rapidjson::Value& array = AsValue(); + if (!array.IsArray() || (idx >= array.GetArray().Size())) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access non-existing array index '") + + std::to_string(idx) + "'"); + } + const auto& v = array[idx]; + if (!v.IsNumber()) { + TRITONJSON_STATUSRETURN( + std::string("attempt to access JSON non-number as double")); + } + *value = v.GetDouble(); + return TRITONJSON_STATUSSUCCESS; + } + + // Release/clear a value. + void Release() + { + if (value_ != nullptr) { + allocator_->Free(value_); + } + } + + private: + // Construct a non-top-level JSON value that references an + // existing element in a document. + explicit Value( + rapidjson::Value& v, rapidjson::Document::AllocatorType* allocator) + : value_(&v), allocator_(allocator) + { + } + + // Return a value object that can be used for both a top-level + // document as well as an element within a document. + const rapidjson::Value& AsValue() const + { + if (value_ == nullptr) { + return document_; + } + return *value_; + } + + rapidjson::Value& AsMutableValue() + { + if (value_ == nullptr) { + return document_; + } + return *value_; + } + + // If this object a document or value. Based on this only one or + // document_ or value_ is valid. 
+ rapidjson::Document document_; + rapidjson::Value* value_; + rapidjson::Document::AllocatorType* allocator_; + }; +}; + +}} // namespace triton::common diff --git a/3rdparty/common-r22.12/protobuf/grpc_service.proto b/3rdparty/common-r22.12/protobuf/grpc_service.proto new file mode 100644 index 0000000000000000000000000000000000000000..b86ba13d4aad9c40a7c58aa530091cd48256607d --- /dev/null +++ b/3rdparty/common-r22.12/protobuf/grpc_service.proto @@ -0,0 +1,1699 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +syntax = "proto3"; + +package inference; + +//@@.. cpp:namespace:: inference + +import "model_config.proto"; + +//@@ +//@@.. cpp:var:: service InferenceService +//@@ +//@@ Inference Server GRPC endpoints. +//@@ +service GRPCInferenceService +{ + //@@ .. cpp:var:: rpc ServerLive(ServerLiveRequest) returns + //@@ (ServerLiveResponse) + //@@ + //@@ Check liveness of the inference server. + //@@ + rpc ServerLive(ServerLiveRequest) returns (ServerLiveResponse) {} + + //@@ .. cpp:var:: rpc ServerReady(ServerReadyRequest) returns + //@@ (ServerReadyResponse) + //@@ + //@@ Check readiness of the inference server. + //@@ + rpc ServerReady(ServerReadyRequest) returns (ServerReadyResponse) {} + + //@@ .. cpp:var:: rpc ModelReady(ModelReadyRequest) returns + //@@ (ModelReadyResponse) + //@@ + //@@ Check readiness of a model in the inference server. + //@@ + rpc ModelReady(ModelReadyRequest) returns (ModelReadyResponse) {} + + //@@ .. cpp:var:: rpc ServerMetadata(ServerMetadataRequest) returns + //@@ (ServerMetadataResponse) + //@@ + //@@ Get server metadata. + //@@ + rpc ServerMetadata(ServerMetadataRequest) returns (ServerMetadataResponse) {} + + //@@ .. cpp:var:: rpc ModelMetadata(ModelMetadataRequest) returns + //@@ (ModelMetadataResponse) + //@@ + //@@ Get model metadata. + //@@ + rpc ModelMetadata(ModelMetadataRequest) returns (ModelMetadataResponse) {} + + //@@ .. 
cpp:var:: rpc ModelInfer(ModelInferRequest) returns + //@@ (ModelInferResponse) + //@@ + //@@ Perform inference using a specific model. + //@@ + rpc ModelInfer(ModelInferRequest) returns (ModelInferResponse) {} + + //@@ .. cpp:var:: rpc ModelStreamInfer(stream ModelInferRequest) returns + //@@ (stream ModelStreamInferResponse) + //@@ + //@@ Perform streaming inference. + //@@ + rpc ModelStreamInfer(stream ModelInferRequest) + returns (stream ModelStreamInferResponse) + { + } + + //@@ .. cpp:var:: rpc ModelConfig(ModelConfigRequest) returns + //@@ (ModelConfigResponse) + //@@ + //@@ Get model configuration. + //@@ + rpc ModelConfig(ModelConfigRequest) returns (ModelConfigResponse) {} + + //@@ .. cpp:var:: rpc ModelStatistics( + //@@ ModelStatisticsRequest) + //@@ returns (ModelStatisticsResponse) + //@@ + //@@ Get the cumulative inference statistics for a model. + //@@ + rpc ModelStatistics(ModelStatisticsRequest) returns (ModelStatisticsResponse) + { + } + + //@@ .. cpp:var:: rpc RepositoryIndex(RepositoryIndexRequest) returns + //@@ (RepositoryIndexResponse) + //@@ + //@@ Get the index of model repository contents. + //@@ + rpc RepositoryIndex(RepositoryIndexRequest) returns (RepositoryIndexResponse) + { + } + + //@@ .. cpp:var:: rpc RepositoryModelLoad(RepositoryModelLoadRequest) returns + //@@ (RepositoryModelLoadResponse) + //@@ + //@@ Load or reload a model from a repository. + //@@ + rpc RepositoryModelLoad(RepositoryModelLoadRequest) + returns (RepositoryModelLoadResponse) + { + } + + //@@ .. cpp:var:: rpc RepositoryModelUnload(RepositoryModelUnloadRequest) + //@@ returns (RepositoryModelUnloadResponse) + //@@ + //@@ Unload a model. + //@@ + rpc RepositoryModelUnload(RepositoryModelUnloadRequest) + returns (RepositoryModelUnloadResponse) + { + } + + //@@ .. cpp:var:: rpc SystemSharedMemoryStatus( + //@@ SystemSharedMemoryStatusRequest) + //@@ returns (SystemSharedMemoryStatusRespose) + //@@ + //@@ Get the status of all registered system-shared-memory regions. + //@@ + rpc SystemSharedMemoryStatus(SystemSharedMemoryStatusRequest) + returns (SystemSharedMemoryStatusResponse) + { + } + + //@@ .. cpp:var:: rpc SystemSharedMemoryRegister( + //@@ SystemSharedMemoryRegisterRequest) + //@@ returns (SystemSharedMemoryRegisterResponse) + //@@ + //@@ Register a system-shared-memory region. + //@@ + rpc SystemSharedMemoryRegister(SystemSharedMemoryRegisterRequest) + returns (SystemSharedMemoryRegisterResponse) + { + } + + //@@ .. cpp:var:: rpc SystemSharedMemoryUnregister( + //@@ SystemSharedMemoryUnregisterRequest) + //@@ returns (SystemSharedMemoryUnregisterResponse) + //@@ + //@@ Unregister a system-shared-memory region. + //@@ + rpc SystemSharedMemoryUnregister(SystemSharedMemoryUnregisterRequest) + returns (SystemSharedMemoryUnregisterResponse) + { + } + + //@@ .. cpp:var:: rpc CudaSharedMemoryStatus( + //@@ CudaSharedMemoryStatusRequest) + //@@ returns (CudaSharedMemoryStatusRespose) + //@@ + //@@ Get the status of all registered CUDA-shared-memory regions. + //@@ + rpc CudaSharedMemoryStatus(CudaSharedMemoryStatusRequest) + returns (CudaSharedMemoryStatusResponse) + { + } + + //@@ .. cpp:var:: rpc CudaSharedMemoryRegister( + //@@ CudaSharedMemoryRegisterRequest) + //@@ returns (CudaSharedMemoryRegisterResponse) + //@@ + //@@ Register a CUDA-shared-memory region. + //@@ + rpc CudaSharedMemoryRegister(CudaSharedMemoryRegisterRequest) + returns (CudaSharedMemoryRegisterResponse) + { + } + + //@@ .. 
cpp:var:: rpc CudaSharedMemoryUnregister( + //@@ CudaSharedMemoryUnregisterRequest) + //@@ returns (CudaSharedMemoryUnregisterResponse) + //@@ + //@@ Unregister a CUDA-shared-memory region. + //@@ + rpc CudaSharedMemoryUnregister(CudaSharedMemoryUnregisterRequest) + returns (CudaSharedMemoryUnregisterResponse) + { + } + + //@@ .. cpp:var:: rpc TraceSetting(TraceSettingRequest) + //@@ returns (TraceSettingResponse) + //@@ + //@@ Update and get the trace setting of the Triton server. + //@@ + rpc TraceSetting(TraceSettingRequest) returns (TraceSettingResponse) + { + } + + //@@ .. cpp:var:: rpc LogSettings(LogSettingsRequest) + //@@ returns (LogSettingsResponse) + //@@ + //@@ Update and get the log settings of the Triton server. + //@@ + rpc LogSettings(LogSettingsRequest) returns (LogSettingsResponse) + { + } +} + +//@@ +//@@.. cpp:var:: message ServerLiveRequest +//@@ +//@@ Request message for ServerLive. +//@@ +message ServerLiveRequest {} + +//@@ +//@@.. cpp:var:: message ServerLiveResponse +//@@ +//@@ Response message for ServerLive. +//@@ +message ServerLiveResponse +{ + //@@ + //@@ .. cpp:var:: bool live + //@@ + //@@ True if the inference server is live, false it not live. + //@@ + bool live = 1; +} + +//@@ +//@@.. cpp:var:: message ServerReadyRequest +//@@ +//@@ Request message for ServerReady. +//@@ +message ServerReadyRequest {} + +//@@ +//@@.. cpp:var:: message ServerReadyResponse +//@@ +//@@ Response message for ServerReady. +//@@ +message ServerReadyResponse +{ + //@@ + //@@ .. cpp:var:: bool ready + //@@ + //@@ True if the inference server is ready, false it not ready. + //@@ + bool ready = 1; +} + +//@@ +//@@.. cpp:var:: message ModelReadyRequest +//@@ +//@@ Request message for ModelReady. +//@@ +message ModelReadyRequest +{ + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the model to check for readiness. + //@@ + string name = 1; + + //@@ .. cpp:var:: string version + //@@ + //@@ The version of the model to check for readiness. If not given the + //@@ server will choose a version based on the model and internal policy. + //@@ + string version = 2; +} + +//@@ +//@@.. cpp:var:: message ModelReadyResponse +//@@ +//@@ Response message for ModelReady. +//@@ +message ModelReadyResponse +{ + //@@ + //@@ .. cpp:var:: bool ready + //@@ + //@@ True if the model is ready, false it not ready. + //@@ + bool ready = 1; +} + +//@@ +//@@.. cpp:var:: message ServerMetadataRequest +//@@ +//@@ Request message for ServerMetadata. +//@@ +message ServerMetadataRequest {} + +//@@ +//@@.. cpp:var:: message ServerMetadataResponse +//@@ +//@@ Response message for ServerMetadata. +//@@ +message ServerMetadataResponse +{ + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The server name. + //@@ + string name = 1; + + //@@ + //@@ .. cpp:var:: string version + //@@ + //@@ The server version. + //@@ + string version = 2; + + //@@ + //@@ .. cpp:var:: string extensions (repeated) + //@@ + //@@ The extensions supported by the server. + //@@ + repeated string extensions = 3; +} + +//@@ +//@@.. cpp:var:: message ModelMetadataRequest +//@@ +//@@ Request message for ModelMetadata. +//@@ +message ModelMetadataRequest +{ + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the model. + //@@ + string name = 1; + + //@@ .. cpp:var:: string version + //@@ + //@@ The version of the model to check for readiness. If not + //@@ given the server will choose a version based on the + //@@ model and internal policy. + //@@ + string version = 2; +} + +//@@ +//@@.. 
cpp:var:: message ModelMetadataResponse +//@@ +//@@ Response message for ModelMetadata. +//@@ +message ModelMetadataResponse +{ + //@@ + //@@ .. cpp:var:: message TensorMetadata + //@@ + //@@ Metadata for a tensor. + //@@ + message TensorMetadata + { + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The tensor name. + //@@ + string name = 1; + + //@@ + //@@ .. cpp:var:: string datatype + //@@ + //@@ The tensor data type. + //@@ + string datatype = 2; + + //@@ + //@@ .. cpp:var:: int64 shape (repeated) + //@@ + //@@ The tensor shape. A variable-size dimension is represented + //@@ by a -1 value. + //@@ + repeated int64 shape = 3; + } + + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The model name. + //@@ + string name = 1; + + //@@ + //@@ .. cpp:var:: string versions (repeated) + //@@ + //@@ The versions of the model. + //@@ + repeated string versions = 2; + + //@@ + //@@ .. cpp:var:: string platform + //@@ + //@@ The model's platform. + //@@ + string platform = 3; + + //@@ + //@@ .. cpp:var:: TensorMetadata inputs (repeated) + //@@ + //@@ The model's inputs. + //@@ + repeated TensorMetadata inputs = 4; + + //@@ + //@@ .. cpp:var:: TensorMetadata outputs (repeated) + //@@ + //@@ The model's outputs. + //@@ + repeated TensorMetadata outputs = 5; +} + +//@@ +//@@.. cpp:var:: message InferParameter +//@@ +//@@ An inference parameter value. +//@@ +message InferParameter +{ + //@@ .. cpp:var:: oneof parameter_choice + //@@ + //@@ The parameter value can be a string, an int64 or + //@@ a boolean + //@@ + oneof parameter_choice + { + //@@ .. cpp:var:: bool bool_param + //@@ + //@@ A boolean parameter value. + //@@ + bool bool_param = 1; + + //@@ .. cpp:var:: int64 int64_param + //@@ + //@@ An int64 parameter value. + //@@ + int64 int64_param = 2; + + //@@ .. cpp:var:: string string_param + //@@ + //@@ A string parameter value. + //@@ + string string_param = 3; + } +} + +//@@ +//@@.. cpp:var:: message InferTensorContents +//@@ +//@@ The data contained in a tensor represented by the repeated type +//@@ that matches the tensor's data type. Protobuf oneof is not used +//@@ because oneofs cannot contain repeated fields. +//@@ +message InferTensorContents +{ + //@@ + //@@ .. cpp:var:: bool bool_contents (repeated) + //@@ + //@@ Representation for BOOL data type. The size must match what is + //@@ expected by the tensor's shape. The contents must be the flattened, + //@@ one-dimensional, row-major order of the tensor elements. + //@@ + repeated bool bool_contents = 1; + + //@@ + //@@ .. cpp:var:: int32 int_contents (repeated) + //@@ + //@@ Representation for INT8, INT16, and INT32 data types. The size + //@@ must match what is expected by the tensor's shape. The contents + //@@ must be the flattened, one-dimensional, row-major order of the + //@@ tensor elements. + //@@ + repeated int32 int_contents = 2; + + //@@ + //@@ .. cpp:var:: int64 int64_contents (repeated) + //@@ + //@@ Representation for INT64 data types. The size must match what + //@@ is expected by the tensor's shape. The contents must be the + //@@ flattened, one-dimensional, row-major order of the tensor elements. + //@@ + repeated int64 int64_contents = 3; + + //@@ + //@@ .. cpp:var:: uint32 uint_contents (repeated) + //@@ + //@@ Representation for UINT8, UINT16, and UINT32 data types. The size + //@@ must match what is expected by the tensor's shape. The contents + //@@ must be the flattened, one-dimensional, row-major order of the + //@@ tensor elements. + //@@ + repeated uint32 uint_contents = 4; + + //@@ + //@@ .. 
cpp:var:: uint64 uint64_contents (repeated) + //@@ + //@@ Representation for UINT64 data types. The size must match what + //@@ is expected by the tensor's shape. The contents must be the + //@@ flattened, one-dimensional, row-major order of the tensor elements. + //@@ + repeated uint64 uint64_contents = 5; + + //@@ + //@@ .. cpp:var:: float fp32_contents (repeated) + //@@ + //@@ Representation for FP32 data type. The size must match what is + //@@ expected by the tensor's shape. The contents must be the flattened, + //@@ one-dimensional, row-major order of the tensor elements. + //@@ + repeated float fp32_contents = 6; + + //@@ + //@@ .. cpp:var:: double fp64_contents (repeated) + //@@ + //@@ Representation for FP64 data type. The size must match what is + //@@ expected by the tensor's shape. The contents must be the flattened, + //@@ one-dimensional, row-major order of the tensor elements. + //@@ + repeated double fp64_contents = 7; + + //@@ + //@@ .. cpp:var:: bytes bytes_contents (repeated) + //@@ + //@@ Representation for BYTES data type. The size must match what is + //@@ expected by the tensor's shape. The contents must be the flattened, + //@@ one-dimensional, row-major order of the tensor elements. + //@@ + repeated bytes bytes_contents = 8; +} + +//@@ +//@@.. cpp:var:: message ModelInferRequest +//@@ +//@@ Request message for ModelInfer. +//@@ +message ModelInferRequest +{ + //@@ + //@@ .. cpp:var:: message InferInputTensor + //@@ + //@@ An input tensor for an inference request. + //@@ + message InferInputTensor + { + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The tensor name. + //@@ + string name = 1; + + //@@ + //@@ .. cpp:var:: string datatype + //@@ + //@@ The tensor data type. + //@@ + string datatype = 2; + + //@@ + //@@ .. cpp:var:: int64 shape (repeated) + //@@ + //@@ The tensor shape. + //@@ + repeated int64 shape = 3; + + //@@ .. cpp:var:: map parameters + //@@ + //@@ Optional inference input tensor parameters. + //@@ + map parameters = 4; + + //@@ .. cpp:var:: InferTensorContents contents + //@@ + //@@ The tensor contents using a data-type format. This field + //@@ must not be specified if tensor contents are being specified + //@@ in ModelInferRequest.raw_input_contents. + //@@ + InferTensorContents contents = 5; + } + + //@@ + //@@ .. cpp:var:: message InferRequestedOutputTensor + //@@ + //@@ An output tensor requested for an inference request. + //@@ + message InferRequestedOutputTensor + { + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The tensor name. + //@@ + string name = 1; + + //@@ .. cpp:var:: map parameters + //@@ + //@@ Optional requested output tensor parameters. + //@@ + map parameters = 2; + } + + //@@ .. cpp:var:: string model_name + //@@ + //@@ The name of the model to use for inferencing. + //@@ + string model_name = 1; + + //@@ .. cpp:var:: string model_version + //@@ + //@@ The version of the model to use for inference. If not + //@@ given the latest/most-recent version of the model is used. + //@@ + string model_version = 2; + + //@@ .. cpp:var:: string id + //@@ + //@@ Optional identifier for the request. If specified will be + //@@ returned in the response. + //@@ + string id = 3; + + //@@ .. cpp:var:: map parameters + //@@ + //@@ Optional inference parameters. + //@@ + map parameters = 4; + + //@@ + //@@ .. cpp:var:: InferInputTensor inputs (repeated) + //@@ + //@@ The input tensors for the inference. + //@@ + repeated InferInputTensor inputs = 5; + + //@@ + //@@ .. 
cpp:var:: InferRequestedOutputTensor outputs (repeated) + //@@ + //@@ The requested output tensors for the inference. Optional, if not + //@@ specified all outputs specified in the model config will be + //@@ returned. + //@@ + repeated InferRequestedOutputTensor outputs = 6; + + //@@ + //@@ .. cpp:var:: bytes raw_input_contents + //@@ + //@@ The data contained in an input tensor can be represented in + //@@ "raw" bytes form or in the repeated type that matches the + //@@ tensor's data type. Using the "raw" bytes form will + //@@ typically allow higher performance due to the way protobuf + //@@ allocation and reuse interacts with GRPC. For example, see + //@@ https://github.com/grpc/grpc/issues/23231. + //@@ + //@@ To use the raw representation 'raw_input_contents' must be + //@@ initialized with data for each tensor in the same order as + //@@ 'inputs'. For each tensor, the size of this content must + //@@ match what is expected by the tensor's shape and data + //@@ type. The raw data must be the flattened, one-dimensional, + //@@ row-major order of the tensor elements without any stride + //@@ or padding between the elements. Note that the FP16 and BF16 data + //@@ types must be represented as raw content as there is no + //@@ specific data type for a 16-bit float type. + //@@ + //@@ If this field is specified then InferInputTensor::contents + //@@ must not be specified for any input tensor. + //@@ + repeated bytes raw_input_contents = 7; +} + +//@@ +//@@.. cpp:var:: message ModelInferResponse +//@@ +//@@ Response message for ModelInfer. +//@@ +message ModelInferResponse +{ + //@@ + //@@ .. cpp:var:: message InferOutputTensor + //@@ + //@@ An output tensor returned for an inference request. + //@@ + message InferOutputTensor + { + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The tensor name. + //@@ + string name = 1; + + //@@ + //@@ .. cpp:var:: string datatype + //@@ + //@@ The tensor data type. + //@@ + string datatype = 2; + + //@@ + //@@ .. cpp:var:: int64 shape (repeated) + //@@ + //@@ The tensor shape. + //@@ + repeated int64 shape = 3; + + //@@ .. cpp:var:: map parameters + //@@ + //@@ Optional output tensor parameters. + //@@ + map parameters = 4; + + //@@ .. cpp:var:: InferTensorContents contents + //@@ + //@@ The tensor contents using a data-type format. This field + //@@ must not be specified if tensor contents are being specified + //@@ in ModelInferResponse.raw_output_contents. + //@@ + InferTensorContents contents = 5; + } + + //@@ .. cpp:var:: string model_name + //@@ + //@@ The name of the model used for inference. + //@@ + string model_name = 1; + + //@@ .. cpp:var:: string model_version + //@@ + //@@ The version of the model used for inference. + //@@ + string model_version = 2; + + //@@ .. cpp:var:: string id + //@@ + //@@ The id of the inference request if one was specified. + //@@ + string id = 3; + + //@@ .. cpp:var:: map parameters + //@@ + //@@ Optional inference response parameters. + //@@ + map parameters = 4; + + //@@ + //@@ .. cpp:var:: InferOutputTensor outputs (repeated) + //@@ + //@@ The output tensors holding inference results. + //@@ + repeated InferOutputTensor outputs = 5; + + //@@ + //@@ .. cpp:var:: bytes raw_output_contents + //@@ + //@@ The data contained in an output tensor can be represented in + //@@ "raw" bytes form or in the repeated type that matches the + //@@ tensor's data type. Using the "raw" bytes form will + //@@ typically allow higher performance due to the way protobuf + //@@ allocation and reuse interacts with GRPC. 
For example, see + //@@ https://github.com/grpc/grpc/issues/23231. + //@@ + //@@ To use the raw representation 'raw_output_contents' must be + //@@ initialized with data for each tensor in the same order as + //@@ 'outputs'. For each tensor, the size of this content must + //@@ match what is expected by the tensor's shape and data + //@@ type. The raw data must be the flattened, one-dimensional, + //@@ row-major order of the tensor elements without any stride + //@@ or padding between the elements. Note that the FP16 and BF16 data + //@@ types must be represented as raw content as there is no + //@@ specific data type for a 16-bit float type. + //@@ + //@@ If this field is specified then InferOutputTensor::contents + //@@ must not be specified for any output tensor. + //@@ + repeated bytes raw_output_contents = 6; +} + +//@@ +//@@.. cpp:var:: message ModelStreamInferResponse +//@@ +//@@ Response message for ModelStreamInfer. +//@@ +message ModelStreamInferResponse +{ + //@@ + //@@ .. cpp:var:: string error_message + //@@ + //@@ The message describing the error. The empty message + //@@ indicates the inference was successful without errors. + //@@ + string error_message = 1; + + //@@ + //@@ .. cpp:var:: ModelInferResponse infer_response + //@@ + //@@ Holds the results of the request. + //@@ + ModelInferResponse infer_response = 2; +} + +//@@ +//@@.. cpp:var:: message ModelConfigRequest +//@@ +//@@ Request message for ModelConfig. +//@@ +message ModelConfigRequest +{ + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the model. + //@@ + string name = 1; + + //@@ .. cpp:var:: string version + //@@ + //@@ The version of the model. If not given the model version + //@@ is selected automatically based on the version policy. + //@@ + string version = 2; +} + +//@@ +//@@.. cpp:var:: message ModelConfigResponse +//@@ +//@@ Response message for ModelConfig. +//@@ +message ModelConfigResponse +{ + //@@ + //@@ .. cpp:var:: ModelConfig config + //@@ + //@@ The model configuration. + //@@ + ModelConfig config = 1; +} + +//@@ +//@@.. cpp:var:: message ModelStatisticsRequest +//@@ +//@@ Request message for ModelStatistics. +//@@ +message ModelStatisticsRequest +{ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the model. If not given returns statistics for + //@@ all models. + //@@ + string name = 1; + + //@@ .. cpp:var:: string version + //@@ + //@@ The version of the model. If not given returns statistics for + //@@ all model versions. + //@@ + string version = 2; +} + + +//@@ +//@@.. cpp:var:: message StatisticDuration +//@@ +//@@ Statistic recording a cumulative duration metric. +//@@ +message StatisticDuration +{ + //@@ .. cpp:var:: uint64 count + //@@ + //@@ Cumulative number of times this metric occurred. + //@@ + uint64 count = 1; + + //@@ .. cpp:var:: uint64 total_time_ns + //@@ + //@@ Total collected duration of this metric in nanoseconds. + //@@ + uint64 ns = 2; +} + +//@@ +//@@.. cpp:var:: message InferStatistics +//@@ +//@@ Inference statistics. +//@@ +message InferStatistics +{ + //@@ .. cpp:var:: StatisticDuration success + //@@ + //@@ Cumulative count and duration for successful inference + //@@ request. The "success" count and cumulative duration includes + //@@ cache hits. + //@@ + StatisticDuration success = 1; + + //@@ .. cpp:var:: StatisticDuration fail + //@@ + //@@ Cumulative count and duration for failed inference + //@@ request. + //@@ + StatisticDuration fail = 2; + + //@@ .. 
cpp:var:: StatisticDuration queue + //@@ + //@@ The count and cumulative duration that inference requests wait in + //@@ scheduling or other queues. The "queue" count and cumulative + //@@ duration includes cache hits. + //@@ + StatisticDuration queue = 3; + + //@@ .. cpp:var:: StatisticDuration compute_input + //@@ + //@@ The count and cumulative duration to prepare input tensor data as + //@@ required by the model framework / backend. For example, this duration + //@@ should include the time to copy input tensor data to the GPU. + //@@ The "compute_input" count and cumulative duration do not account for + //@@ requests that were a cache hit. See the "cache_hit" field for more + //@@ info. + //@@ + StatisticDuration compute_input = 4; + + //@@ .. cpp:var:: StatisticDuration compute_infer + //@@ + //@@ The count and cumulative duration to execute the model. + //@@ The "compute_infer" count and cumulative duration do not account for + //@@ requests that were a cache hit. See the "cache_hit" field for more + //@@ info. + //@@ + StatisticDuration compute_infer = 5; + + //@@ .. cpp:var:: StatisticDuration compute_output + //@@ + //@@ The count and cumulative duration to extract output tensor data + //@@ produced by the model framework / backend. For example, this duration + //@@ should include the time to copy output tensor data from the GPU. + //@@ The "compute_output" count and cumulative duration do not account for + //@@ requests that were a cache hit. See the "cache_hit" field for more + //@@ info. + //@@ + StatisticDuration compute_output = 6; + + //@@ .. cpp:var:: StatisticDuration cache_hit + //@@ + //@@ The count of response cache hits and cumulative duration to lookup + //@@ and extract output tensor data from the Response Cache on a cache + //@@ hit. For example, this duration should include the time to copy + //@@ output tensor data from the Response Cache to the response object. + //@@ On cache hits, triton does not need to go to the model/backend + //@@ for the output tensor data, so the "compute_input", "compute_infer", + //@@ and "compute_output" fields are not updated. Assuming the response + //@@ cache is enabled for a given model, a cache hit occurs for a + //@@ request to that model when the request metadata (model name, + //@@ model version, model inputs) hashes to an existing entry in the + //@@ cache. On a cache miss, the request hash and response output tensor + //@@ data is added to the cache. See response cache docs for more info: + //@@ https://github.com/triton-inference-server/server/blob/main/docs/response_cache.md + //@@ + StatisticDuration cache_hit = 7; + + //@@ .. cpp:var:: StatisticDuration cache_miss + //@@ + //@@ The count of response cache misses and cumulative duration to lookup + //@@ and insert output tensor data from the computed response to the cache. + //@@ For example, this duration should include the time to copy + //@@ output tensor data from the response object to the Response Cache. + //@@ Assuming the response cache is enabled for a given model, a cache + //@@ miss occurs for a request to that model when the request metadata + //@@ does NOT hash to an existing entry in the cache. See the response + //@@ cache docs for more info: + //@@ https://github.com/triton-inference-server/server/blob/main/docs/response_cache.md + //@@ + StatisticDuration cache_miss = 8; +} + +//@@ +//@@.. cpp:var:: message InferBatchStatistics +//@@ +//@@ Inference batch statistics. +//@@ +message InferBatchStatistics +{ + //@@ .. 
cpp:var:: uint64 batch_size + //@@ + //@@ The size of the batch. + //@@ + uint64 batch_size = 1; + + //@@ .. cpp:var:: StatisticDuration compute_input + //@@ + //@@ The count and cumulative duration to prepare input tensor data as + //@@ required by the model framework / backend with the given batch size. + //@@ For example, this duration should include the time to copy input + //@@ tensor data to the GPU. + //@@ + StatisticDuration compute_input = 2; + + //@@ .. cpp:var:: StatisticDuration compute_infer + //@@ + //@@ The count and cumulative duration to execute the model with the given + //@@ batch size. + //@@ + StatisticDuration compute_infer = 3; + + //@@ .. cpp:var:: StatisticDuration compute_output + //@@ + //@@ The count and cumulative duration to extract output tensor data + //@@ produced by the model framework / backend with the given batch size. + //@@ For example, this duration should include the time to copy output + //@@ tensor data from the GPU. + //@@ + StatisticDuration compute_output = 4; +} + +//@@ +//@@.. cpp:var:: message ModelStatistics +//@@ +//@@ Statistics for a specific model and version. +//@@ +message ModelStatistics +{ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the model. If not given returns statistics for all + //@@ + string name = 1; + + //@@ .. cpp:var:: string version + //@@ + //@@ The version of the model. + //@@ + string version = 2; + + //@@ .. cpp:var:: uint64 last_inference + //@@ + //@@ The timestamp of the last inference request made for this model, + //@@ as milliseconds since the epoch. + //@@ + uint64 last_inference = 3; + + //@@ .. cpp:var:: uint64 last_inference + //@@ + //@@ The cumulative count of successful inference requests made for this + //@@ model. Each inference in a batched request is counted as an + //@@ individual inference. For example, if a client sends a single + //@@ inference request with batch size 64, "inference_count" will be + //@@ incremented by 64. Similarly, if a clients sends 64 individual + //@@ requests each with batch size 1, "inference_count" will be + //@@ incremented by 64. The "inference_count" value DOES NOT include + //@@ cache hits. + //@@ + uint64 inference_count = 4; + + //@@ .. cpp:var:: uint64 last_inference + //@@ + //@@ The cumulative count of the number of successful inference executions + //@@ performed for the model. When dynamic batching is enabled, a single + //@@ model execution can perform inferencing for more than one inference + //@@ request. For example, if a clients sends 64 individual requests each + //@@ with batch size 1 and the dynamic batcher batches them into a single + //@@ large batch for model execution then "execution_count" will be + //@@ incremented by 1. If, on the other hand, the dynamic batcher is not + //@@ enabled for that each of the 64 individual requests is executed + //@@ independently, then "execution_count" will be incremented by 64. + //@@ The "execution_count" value DOES NOT include cache hits. + //@@ + uint64 execution_count = 5; + + //@@ .. cpp:var:: InferStatistics inference_stats + //@@ + //@@ The aggregate statistics for the model/version. + //@@ + InferStatistics inference_stats = 6; + + //@@ .. cpp:var:: InferBatchStatistics batch_stats (repeated) + //@@ + //@@ The aggregate statistics for each different batch size that is + //@@ executed in the model. 
The batch statistics indicate how many actual + //@@ model executions were performed and show differences due to different + //@@ batch size (for example, larger batches typically take longer to + //@@ compute). + //@@ + repeated InferBatchStatistics batch_stats = 7; +} + +//@@ +//@@.. cpp:var:: message ModelStatisticsResponse +//@@ +//@@ Response message for ModelStatistics. +//@@ +message ModelStatisticsResponse +{ + //@@ .. cpp:var:: ModelStatistics model_stats (repeated) + //@@ + //@@ Statistics for each requested model. + //@@ + repeated ModelStatistics model_stats = 1; +} + +//@@ +//@@.. cpp:var:: message ModelRepositoryParameter +//@@ +//@@ An model repository parameter value. +//@@ +message ModelRepositoryParameter +{ + //@@ .. cpp:var:: oneof parameter_choice + //@@ + //@@ The parameter value can be a string, an int64 or + //@@ a boolean + //@@ + oneof parameter_choice + { + //@@ .. cpp:var:: bool bool_param + //@@ + //@@ A boolean parameter value. + //@@ + bool bool_param = 1; + + //@@ .. cpp:var:: int64 int64_param + //@@ + //@@ An int64 parameter value. + //@@ + int64 int64_param = 2; + + //@@ .. cpp:var:: string string_param + //@@ + //@@ A string parameter value. + //@@ + string string_param = 3; + + //@@ .. cpp:var:: bytes bytes_param + //@@ + //@@ A bytes parameter value. + //@@ + bytes bytes_param = 4; + } +} + +//@@ +//@@.. cpp:var:: message RepositoryIndexRequest +//@@ +//@@ Request message for RepositoryIndex. +//@@ +message RepositoryIndexRequest +{ + //@@ .. cpp:var:: string repository_name + //@@ + //@@ The name of the repository. If empty the index is returned + //@@ for all repositories. + //@@ + string repository_name = 1; + + //@@ .. cpp:var:: bool ready + //@@ + //@@ If true returned only models currently ready for inferencing. + //@@ + bool ready = 2; +} + +//@@ +//@@.. cpp:var:: message RepositoryIndexResponse +//@@ +//@@ Response message for RepositoryIndex. +//@@ +message RepositoryIndexResponse +{ + //@@ + //@@ .. cpp:var:: message ModelIndex + //@@ + //@@ Index entry for a model. + //@@ + message ModelIndex + { + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the model. + //@@ + string name = 1; + + //@@ .. cpp:var:: string version + //@@ + //@@ The version of the model. + //@@ + string version = 2; + + //@@ + //@@ .. cpp:var:: string state + //@@ + //@@ The state of the model. + //@@ + string state = 3; + + //@@ + //@@ .. cpp:var:: string reason + //@@ + //@@ The reason, if any, that the model is in the given state. + //@@ + string reason = 4; + } + + //@@ + //@@ .. cpp:var:: ModelIndex models (repeated) + //@@ + //@@ An index entry for each model. + //@@ + repeated ModelIndex models = 1; +} + +//@@ +//@@.. cpp:var:: message RepositoryModelLoadRequest +//@@ +//@@ Request message for RepositoryModelLoad. +//@@ +message RepositoryModelLoadRequest +{ + //@@ .. cpp:var:: string repository_name + //@@ + //@@ The name of the repository to load from. If empty the model + //@@ is loaded from any repository. + //@@ + string repository_name = 1; + + //@@ .. cpp:var:: string repository_name + //@@ + //@@ The name of the model to load, or reload. + //@@ + string model_name = 2; + + //@@ .. cpp:var:: map parameters + //@@ + //@@ Optional model repository request parameters. + //@@ + map parameters = 3; +} + +//@@ +//@@.. cpp:var:: message RepositoryModelLoadResponse +//@@ +//@@ Response message for RepositoryModelLoad. +//@@ +message RepositoryModelLoadResponse {} + +//@@ +//@@.. 
cpp:var:: message RepositoryModelUnloadRequest +//@@ +//@@ Request message for RepositoryModelUnload. +//@@ +message RepositoryModelUnloadRequest +{ + //@@ .. cpp:var:: string repository_name + //@@ + //@@ The name of the repository from which the model was originally + //@@ loaded. If empty the repository is not considered. + //@@ + string repository_name = 1; + + //@@ .. cpp:var:: string repository_name + //@@ + //@@ The name of the model to unload. + //@@ + string model_name = 2; + + //@@ .. cpp:var:: map parameters + //@@ + //@@ Optional model repository request parameters. + //@@ + map parameters = 3; +} + +//@@ +//@@.. cpp:var:: message RepositoryModelUnloadResponse +//@@ +//@@ Response message for RepositoryModelUnload. +//@@ +message RepositoryModelUnloadResponse {} + +//@@ +//@@.. cpp:var:: message SystemSharedMemoryStatusRequest +//@@ +//@@ Request message for SystemSharedMemoryStatus. +//@@ +message SystemSharedMemoryStatusRequest +{ + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the region to get status for. If empty the + //@@ status is returned for all registered regions. + //@@ + string name = 1; +} + +//@@ +//@@.. cpp:var:: message SystemSharedMemoryStatusResponse +//@@ +//@@ Response message for SystemSharedMemoryStatus. +//@@ +message SystemSharedMemoryStatusResponse +{ + //@@ + //@@ .. cpp:var:: message RegionStatus + //@@ + //@@ Status for a shared memory region. + //@@ + message RegionStatus + { + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name for the shared memory region. + //@@ + string name = 1; + + //@@ .. cpp:var:: string shared_memory_key + //@@ + //@@ The key of the underlying memory object that contains the + //@@ shared memory region. + //@@ + string key = 2; + + //@@ .. cpp:var:: uint64 offset + //@@ + //@@ Offset, in bytes, within the underlying memory object to + //@@ the start of the shared memory region. + //@@ + uint64 offset = 3; + + //@@ .. cpp:var:: uint64 byte_size + //@@ + //@@ Size of the shared memory region, in bytes. + //@@ + uint64 byte_size = 4; + } + + //@@ + //@@ .. cpp:var:: map regions + //@@ + //@@ Status for each of the registered regions, indexed by + //@@ region name. + //@@ + map regions = 1; +} + +//@@ +//@@.. cpp:var:: message SystemSharedMemoryRegisterRequest +//@@ +//@@ Request message for SystemSharedMemoryRegister. +//@@ +message SystemSharedMemoryRegisterRequest +{ + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the region to register. + //@@ + string name = 1; + + //@@ .. cpp:var:: string shared_memory_key + //@@ + //@@ The key of the underlying memory object that contains the + //@@ shared memory region. + //@@ + string key = 2; + + //@@ .. cpp:var:: uint64 offset + //@@ + //@@ Offset, in bytes, within the underlying memory object to + //@@ the start of the shared memory region. + //@@ + uint64 offset = 3; + + //@@ .. cpp:var:: uint64 byte_size + //@@ + //@@ Size of the shared memory region, in bytes. + //@@ + uint64 byte_size = 4; +} + +//@@ +//@@.. cpp:var:: message SystemSharedMemoryRegisterResponse +//@@ +//@@ Response message for SystemSharedMemoryRegister. +//@@ +message SystemSharedMemoryRegisterResponse {} + +//@@ +//@@.. cpp:var:: message SystemSharedMemoryUnregisterRequest +//@@ +//@@ Request message for SystemSharedMemoryUnregister. +//@@ +message SystemSharedMemoryUnregisterRequest +{ + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the system region to unregister. If empty + //@@ all system shared-memory regions are unregistered. 
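[Editor's note] To make the system shared-memory messages above more concrete, here is a minimal sketch of a SystemSharedMemoryRegisterRequest in Protocol Buffers text format. The region name, key, and byte size are hypothetical values chosen for illustration and are not part of the upstream file:

    # Hypothetical SystemSharedMemoryRegisterRequest (protobuf text format).
    # Registers a 1 MiB region starting at offset 0 of the shared-memory
    # object identified by "key"; all values are illustrative only.
    name: "input0_region"
    key: "/triton_example_shm"
    offset: 0
    byte_size: 1048576

Once registered, later requests can refer to the region by its name, and it can be removed with the SystemSharedMemoryUnregisterRequest defined next, where leaving the name empty unregisters all regions.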
+ //@@ + string name = 1; +} + +//@@ +//@@.. cpp:var:: message SystemSharedMemoryUnregisterResponse +//@@ +//@@ Response message for SystemSharedMemoryUnregister. +//@@ +message SystemSharedMemoryUnregisterResponse {} + +//@@ +//@@.. cpp:var:: message CudaSharedMemoryStatusRequest +//@@ +//@@ Request message for CudaSharedMemoryStatus. +//@@ +message CudaSharedMemoryStatusRequest +{ + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the region to get status for. If empty the + //@@ status is returned for all registered regions. + //@@ + string name = 1; +} + +//@@ +//@@.. cpp:var:: message CudaSharedMemoryStatusResponse +//@@ +//@@ Response message for CudaSharedMemoryStatus. +//@@ +message CudaSharedMemoryStatusResponse +{ + //@@ + //@@ .. cpp:var:: message RegionStatus + //@@ + //@@ Status for a shared memory region. + //@@ + message RegionStatus + { + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name for the shared memory region. + //@@ + string name = 1; + + //@@ .. cpp:var:: uin64 device_id + //@@ + //@@ The GPU device ID where the cudaIPC handle was created. + //@@ + uint64 device_id = 2; + + //@@ .. cpp:var:: uint64 byte_size + //@@ + //@@ Size of the shared memory region, in bytes. + //@@ + uint64 byte_size = 3; + } + + //@@ + //@@ .. cpp:var:: map regions + //@@ + //@@ Status for each of the registered regions, indexed by + //@@ region name. + //@@ + map regions = 1; +} + +//@@ +//@@.. cpp:var:: message CudaSharedMemoryRegisterRequest +//@@ +//@@ Request message for CudaSharedMemoryRegister. +//@@ +message CudaSharedMemoryRegisterRequest +{ + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the region to register. + //@@ + string name = 1; + + //@@ .. cpp:var:: bytes raw_handle + //@@ + //@@ The raw serialized cudaIPC handle. + //@@ + bytes raw_handle = 2; + + //@@ .. cpp:var:: int64 device_id + //@@ + //@@ The GPU device ID on which the cudaIPC handle was created. + //@@ + int64 device_id = 3; + + //@@ .. cpp:var:: uint64 byte_size + //@@ + //@@ Size of the shared memory block, in bytes. + //@@ + uint64 byte_size = 4; +} + +//@@ +//@@.. cpp:var:: message CudaSharedMemoryRegisterResponse +//@@ +//@@ Response message for CudaSharedMemoryRegister. +//@@ +message CudaSharedMemoryRegisterResponse {} + +//@@ +//@@.. cpp:var:: message CudaSharedMemoryUnregisterRequest +//@@ +//@@ Request message for CudaSharedMemoryUnregister. +//@@ +message CudaSharedMemoryUnregisterRequest +{ + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the cuda region to unregister. If empty + //@@ all cuda shared-memory regions are unregistered. + //@@ + string name = 1; +} + +//@@ +//@@.. cpp:var:: message CudaSharedMemoryUnregisterResponse +//@@ +//@@ Response message for CudaSharedMemoryUnregister. +//@@ +message CudaSharedMemoryUnregisterResponse {} + +//@@ +//@@.. cpp:var:: message TraceSettingRequest +//@@ +//@@ Request message for TraceSetting. +//@@ +message TraceSettingRequest +{ + //@@ + //@@ .. cpp:var:: message SettingValue + //@@ + //@@ The values to be associated with a trace setting. + //@@ If no value is provided, the setting will be clear and + //@@ the global setting value will be used. + //@@ + message SettingValue + { + //@@ + //@@ .. cpp:var:: string value (repeated) + //@@ + //@@ The value. + //@@ + repeated string value = 1; + } + + //@@ .. cpp:var:: map settings + //@@ + //@@ The new setting values to be updated, + //@@ settings that are not specified will remain unchanged. + //@@ + map settings = 1; + + //@@ + //@@ .. 
cpp:var:: string model_name + //@@ + //@@ The name of the model to apply the new trace settings. + //@@ If not given, the new settings will be applied globally. + //@@ + string model_name = 2; +} + +//@@ +//@@.. cpp:var:: message TraceSettingResponse +//@@ +//@@ Response message for TraceSetting. +//@@ +message TraceSettingResponse +{ + //@@ + //@@ .. cpp:var:: message SettingValue + //@@ + //@@ The values to be associated with a trace setting. + //@@ + message SettingValue + { + //@@ + //@@ .. cpp:var:: string value (repeated) + //@@ + //@@ The value. + //@@ + repeated string value = 1; + } + + //@@ .. cpp:var:: map settings + //@@ + //@@ The current trace settings, including any changes specified + //@@ by TraceSettingRequest. + //@@ + map settings = 1; +} + +//@@ +//@@.. cpp:var:: message LogSettingsRequest +//@@ +//@@ Request message for LogSettings. +//@@ +message LogSettingsRequest +{ + message SettingValue + { + oneof parameter_choice + { + //@@ .. cpp:var:: bool bool_param + //@@ + //@@ A boolean parameter value. + //@@ + bool bool_param = 1; + + //@@ .. cpp:var:: uint32 uint32_param + //@@ + //@@ An uint32 parameter value. + //@@ + uint32 uint32_param = 2; + + //@@ .. cpp:var:: string string_param + //@@ + //@@ A string parameter value. + //@@ + string string_param = 3; + } + } + //@@ .. cpp:var:: map settings + //@@ + //@@ The current log settings. + //@@ + map settings = 1; +} + +//@@ +//@@.. cpp:var:: message LogSettingsResponse +//@@ +//@@ Response message for LogSettings. +//@@ +message LogSettingsResponse +{ + message SettingValue + { + oneof parameter_choice + { + //@@ .. cpp:var:: bool bool_param + //@@ + //@@ A boolean parameter value. + //@@ + bool bool_param = 1; + + //@@ .. cpp:var:: uint32 uint32_param + //@@ + //@@ An int32 parameter value. + //@@ + uint32 uint32_param = 2; + + //@@ .. cpp:var:: string string_param + //@@ + //@@ A string parameter value. + //@@ + string string_param = 3; + } + } + //@@ .. cpp:var:: map settings + //@@ + //@@ The current log settings. + //@@ + map settings = 1; +} + diff --git a/3rdparty/common-r22.12/protobuf/model_config.proto b/3rdparty/common-r22.12/protobuf/model_config.proto new file mode 100644 index 0000000000000000000000000000000000000000..b3c5a7b25c792e88e5cc8564d47dd6c048674e46 --- /dev/null +++ b/3rdparty/common-r22.12/protobuf/model_config.proto @@ -0,0 +1,1981 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Copyright (c) 2018, TensorFlow Authors. All rights reserved. + +syntax = "proto3"; + +package inference; + +//@@.. cpp:namespace:: inference + +//@@ +//@@.. cpp:enum:: DataType +//@@ +//@@ Data types supported for input and output tensors. +//@@ +enum DataType { + //@@ .. cpp:enumerator:: DataType::INVALID = 0 + TYPE_INVALID = 0; + + //@@ .. cpp:enumerator:: DataType::BOOL = 1 + TYPE_BOOL = 1; + + //@@ .. cpp:enumerator:: DataType::UINT8 = 2 + TYPE_UINT8 = 2; + //@@ .. cpp:enumerator:: DataType::UINT16 = 3 + TYPE_UINT16 = 3; + //@@ .. cpp:enumerator:: DataType::UINT32 = 4 + TYPE_UINT32 = 4; + //@@ .. cpp:enumerator:: DataType::UINT64 = 5 + TYPE_UINT64 = 5; + + //@@ .. cpp:enumerator:: DataType::INT8 = 6 + TYPE_INT8 = 6; + //@@ .. cpp:enumerator:: DataType::INT16 = 7 + TYPE_INT16 = 7; + //@@ .. cpp:enumerator:: DataType::INT32 = 8 + TYPE_INT32 = 8; + //@@ .. cpp:enumerator:: DataType::INT64 = 9 + TYPE_INT64 = 9; + + //@@ .. cpp:enumerator:: DataType::FP16 = 10 + TYPE_FP16 = 10; + //@@ .. cpp:enumerator:: DataType::FP32 = 11 + TYPE_FP32 = 11; + //@@ .. cpp:enumerator:: DataType::FP64 = 12 + TYPE_FP64 = 12; + + //@@ .. cpp:enumerator:: DataType::STRING = 13 + TYPE_STRING = 13; + + //@@ .. cpp:enumerator:: DataType::BF16 = 14 + TYPE_BF16 = 14; +} + +//@@ +//@@ .. cpp:var:: message ModelRateLimiter +//@@ +//@@ The specifications required by the rate limiter to properly +//@@ schedule the inference requests across the different models +//@@ and their instances. +//@@ +message ModelRateLimiter +{ + //@@ .. cpp:var:: message Resource + //@@ + //@@ The resource property. + //@@ + message Resource + { + //@@ .. cpp:var:: string name + //@@ + //@@ The name associated with the resource. + //@@ + string name = 1; + + //@@ .. cpp:var:: bool global + //@@ + //@@ Whether or not the resource is global. If true then the resource + //@@ is assumed to be shared among the devices otherwise specified + //@@ count of the resource is assumed for each device associated + //@@ with the instance. + //@@ + bool global = 2; + + //@@ .. cpp:var:: uint32 count + //@@ + //@@ The number of resources required for the execution of the model + //@@ instance. + //@@ + uint32 count = 3; + } + + //@@ .. cpp:var:: Resource resources (repeated) + //@@ + //@@ The resources required to execute the request on a model instance. + //@@ Resources are just names with a corresponding count. The execution + //@@ of the instance will be blocked until the specificied resources are + //@@ available. By default an instance uses no rate-limiter resources. + //@@ + repeated Resource resources = 1; + + //@@ .. cpp:var:: uint32 priority + //@@ + //@@ The optional weighting value to be used for prioritizing across + //@@ instances. An instance with priority 2 will be given 1/2 the + //@@ number of scheduling chances as an instance_group with priority + //@@ 1. The default priority is 1. The priority of value 0 will be + //@@ treated as priority 1. + //@@ + uint32 priority = 2; +} + +//@@ +//@@.. 
cpp:var:: message ModelInstanceGroup +//@@ +//@@ A group of one or more instances of a model and resources made +//@@ available for those instances. +//@@ +message ModelInstanceGroup +{ + //@@ + //@@ .. cpp:enum:: Kind + //@@ + //@@ Kind of this instance group. + //@@ + enum Kind { + //@@ .. cpp:enumerator:: Kind::KIND_AUTO = 0 + //@@ + //@@ This instance group represents instances that can run on either + //@@ CPU or GPU. If all GPUs listed in 'gpus' are available then + //@@ instances will be created on GPU(s), otherwise instances will + //@@ be created on CPU. + //@@ + KIND_AUTO = 0; + + //@@ .. cpp:enumerator:: Kind::KIND_GPU = 1 + //@@ + //@@ This instance group represents instances that must run on the + //@@ GPU. + //@@ + KIND_GPU = 1; + + //@@ .. cpp:enumerator:: Kind::KIND_CPU = 2 + //@@ + //@@ This instance group represents instances that must run on the + //@@ CPU. + //@@ + KIND_CPU = 2; + + //@@ .. cpp:enumerator:: Kind::KIND_MODEL = 3 + //@@ + //@@ This instance group represents instances that should run on the + //@@ CPU and/or GPU(s) as specified by the model or backend itself. + //@@ The inference server will not override the model/backend + //@@ settings. + //@@ + KIND_MODEL = 3; + } + + //@@ + //@@ .. cpp:var:: message SecondaryDevice + //@@ + //@@ A secondary device required for a model instance. + //@@ + message SecondaryDevice + { + //@@ + //@@ .. cpp:enum:: SecondaryDeviceKind + //@@ + //@@ The kind of the secondary device. + //@@ + enum SecondaryDeviceKind { + //@@ .. cpp:enumerator:: SecondaryDeviceKind::KIND_NVDLA = 0 + //@@ + //@@ An NVDLA core. http://nvdla.org + //@@ Currently KIND_NVDLA is only supported by the TensorRT backend. + //@@ + KIND_NVDLA = 0; + } + + //@@ .. cpp:var:: SecondaryDeviceKind kind + //@@ + //@@ The secondary device kind. + //@@ + SecondaryDeviceKind kind = 1; + + //@@ .. cpp:var:: int64 device_id + //@@ + //@@ Identifier for the secondary device. + //@@ + int64 device_id = 2; + } + + //@@ .. cpp:var:: string name + //@@ + //@@ Optional name of this group of instances. If not specified the + //@@ name will be formed as _. The name of + //@@ individual instances will be further formed by a unique instance + //@@ number and GPU index: + //@@ + string name = 1; + + //@@ .. cpp:var:: Kind kind + //@@ + //@@ The kind of this instance group. Default is KIND_AUTO. If + //@@ KIND_AUTO or KIND_GPU then both 'count' and 'gpu' are valid and + //@@ may be specified. If KIND_CPU or KIND_MODEL only 'count' is valid + //@@ and 'gpu' cannot be specified. + //@@ + Kind kind = 4; + + //@@ .. cpp:var:: int32 count + //@@ + //@@ For a group assigned to GPU, the number of instances created for + //@@ each GPU listed in 'gpus'. For a group assigned to CPU the number + //@@ of instances created. Default is 1. + int32 count = 2; + + //@@ .. cpp:var:: ModelRateLimiter rate_limiter + //@@ + //@@ The rate limiter specific settings to be associated with this + //@@ instance group. Optional, if not specified no rate limiting + //@@ will be applied to this instance group. + //@@ + ModelRateLimiter rate_limiter = 6; + + //@@ .. cpp:var:: int32 gpus (repeated) + //@@ + //@@ GPU(s) where instances should be available. For each GPU listed, + //@@ 'count' instances of the model will be available. Setting 'gpus' + //@@ to empty (or not specifying at all) is eqivalent to listing all + //@@ available GPUs. + //@@ + repeated int32 gpus = 3; + + //@@ .. 
cpp:var:: SecondaryDevice secondary_devices (repeated) + //@@ + //@@ Secondary devices that are required by instances specified by this + //@@ instance group. Optional. + //@@ + repeated SecondaryDevice secondary_devices = 8; + + //@@ .. cpp:var:: string profile (repeated) + //@@ + //@@ For TensorRT models containing multiple optimization profile, this + //@@ parameter specifies a set of optimization profiles available to this + //@@ instance group. The inference server will choose the optimal profile + //@@ based on the shapes of the input tensors. This field should lie + //@@ between 0 and - 1 + //@@ and be specified only for TensorRT backend, otherwise an error will + //@@ be generated. If not specified, the server will select the first + //@@ optimization profile by default. + //@@ + repeated string profile = 5; + + //@@ .. cpp:var:: bool passive + //@@ + //@@ Whether the instances within this instance group will be accepting + //@@ inference requests from the scheduler. If true, the instances will + //@@ not be added to the scheduler. Default value is false. + //@@ + bool passive = 7; + + //@@ .. cpp:var:: string host_policy + //@@ + //@@ The host policy name that the instance to be associated with. + //@@ The default value is set to reflect the device kind of the instance, + //@@ for instance, KIND_CPU is "cpu", KIND_MODEL is "model" and + //@@ KIND_GPU is "gpu_". + //@@ + string host_policy = 9; +} + +//@@ +//@@.. cpp:var:: message ModelTensorReshape +//@@ +//@@ Reshape specification for input and output tensors. +//@@ +message ModelTensorReshape +{ + //@@ .. cpp:var:: int64 shape (repeated) + //@@ + //@@ The shape to use for reshaping. + //@@ + repeated int64 shape = 1; +} + +//@@ +//@@.. cpp:var:: message ModelInput +//@@ +//@@ An input required by the model. +//@@ +message ModelInput +{ + //@@ + //@@ .. cpp:enum:: Format + //@@ + //@@ The format for the input. + //@@ + enum Format { + //@@ .. cpp:enumerator:: Format::FORMAT_NONE = 0 + //@@ + //@@ The input has no specific format. This is the default. + //@@ + FORMAT_NONE = 0; + + //@@ .. cpp:enumerator:: Format::FORMAT_NHWC = 1 + //@@ + //@@ HWC image format. Tensors with this format require 3 dimensions + //@@ if the model does not support batching (max_batch_size = 0) or 4 + //@@ dimensions if the model does support batching (max_batch_size + //@@ >= 1). In either case the 'dims' below should only specify the + //@@ 3 non-batch dimensions (i.e. HWC or CHW). + //@@ + FORMAT_NHWC = 1; + + //@@ .. cpp:enumerator:: Format::FORMAT_NCHW = 2 + //@@ + //@@ CHW image format. Tensors with this format require 3 dimensions + //@@ if the model does not support batching (max_batch_size = 0) or 4 + //@@ dimensions if the model does support batching (max_batch_size + //@@ >= 1). In either case the 'dims' below should only specify the + //@@ 3 non-batch dimensions (i.e. HWC or CHW). + //@@ + FORMAT_NCHW = 2; + } + + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the input. + //@@ + string name = 1; + + //@@ .. cpp:var:: DataType data_type + //@@ + //@@ The data-type of the input. + //@@ + DataType data_type = 2; + + //@@ .. cpp:var:: Format format + //@@ + //@@ The format of the input. Optional. + //@@ + Format format = 3; + + //@@ .. cpp:var:: int64 dims (repeated) + //@@ + //@@ The dimensions/shape of the input tensor that must be provided + //@@ when invoking the inference API for this model. + //@@ + repeated int64 dims = 4; + + //@@ .. 
cpp:var:: ModelTensorReshape reshape + //@@ + //@@ The shape expected for this input by the backend. The input will + //@@ be reshaped to this before being presented to the backend. The + //@@ reshape must have the same number of elements as the input shape + //@@ specified by 'dims'. Optional. + //@@ + ModelTensorReshape reshape = 5; + + //@@ .. cpp:var:: bool is_shape_tensor + //@@ + //@@ Whether or not the input is a shape tensor to the model. This field + //@@ is currently supported only for the TensorRT model. An error will be + //@@ generated if this specification does not comply with underlying + //@@ model. + //@@ + bool is_shape_tensor = 6; + + //@@ .. cpp:var:: bool allow_ragged_batch + //@@ + //@@ Whether or not the input is allowed to be "ragged" in a dynamically + //@@ created batch. Default is false indicating that two requests will + //@@ only be batched if this tensor has the same shape in both requests. + //@@ True indicates that two requests can be batched even if this tensor + //@@ has a different shape in each request. + //@@ + bool allow_ragged_batch = 7; + + //@@ .. cpp:var:: bool optional + //@@ + //@@ Whether or not the input is optional for the model execution. + //@@ If true, the input is not required in the inference request. + //@@ Default value is false. + //@@ + bool optional = 8; +} + +//@@ +//@@.. cpp:var:: message ModelOutput +//@@ +//@@ An output produced by the model. +//@@ +message ModelOutput +{ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the output. + //@@ + string name = 1; + + //@@ .. cpp:var:: DataType data_type + //@@ + //@@ The data-type of the output. + //@@ + DataType data_type = 2; + + //@@ .. cpp:var:: int64 dims (repeated) + //@@ + //@@ The dimensions/shape of the output tensor. + //@@ + repeated int64 dims = 3; + + //@@ .. cpp:var:: ModelTensorReshape reshape + //@@ + //@@ The shape produced for this output by the backend. The output will + //@@ be reshaped from this to the shape specifed in 'dims' before being + //@@ returned in the inference response. The reshape must have the same + //@@ number of elements as the output shape specified by 'dims'. Optional. + //@@ + ModelTensorReshape reshape = 5; + + //@@ .. cpp:var:: string label_filename + //@@ + //@@ The label file associated with this output. Should be specified only + //@@ for outputs that represent classifications. Optional. + //@@ + string label_filename = 4; + + + //@@ .. cpp:var:: bool is_shape_tensor + //@@ + //@@ Whether or not the output is a shape tensor to the model. This field + //@@ is currently supported only for the TensorRT model. An error will be + //@@ generated if this specification does not comply with underlying + //@@ model. + //@@ + bool is_shape_tensor = 6; +} + +//@@ .. cpp:var:: message BatchInput +//@@ +//@@ A batch input is an additional input that must be added by +//@@ the backend based on all the requests in a batch. +//@@ +message BatchInput +{ + //@@ + //@@ .. cpp:enum:: Kind + //@@ + //@@ The kind of the batch input. + //@@ + enum Kind { + //@@ .. cpp:enumerator:: Kind::BATCH_ELEMENT_COUNT = 0 + //@@ + //@@ The element count of the 'source_input' will be added as + //@@ input with shape [1]. + //@@ + BATCH_ELEMENT_COUNT = 0; + + //@@ .. cpp:enumerator:: Kind::BATCH_ACCUMULATED_ELEMENT_COUNT = 1 + //@@ + //@@ The accumulated element count of the 'source_input' will be + //@@ added as input with shape [1]. 
For example, if there is a + //@@ batch of two request, each with 2 elements, an input of value + //@@ 2 will be added to the first request, and an input of value + //@@ 4 will be added to the second request. + //@@ + BATCH_ACCUMULATED_ELEMENT_COUNT = 1; + + //@@ .. cpp:enumerator:: + //@@ Kind::BATCH_ACCUMULATED_ELEMENT_COUNT_WITH_ZERO = 2 + //@@ + //@@ The accumulated element count of the 'source_input' will be + //@@ added as input with shape [1], except for the first request + //@@ in the batch. For the first request in the batch, the input + //@@ will have shape [2] where the first element is value 0. + //@@ + BATCH_ACCUMULATED_ELEMENT_COUNT_WITH_ZERO = 2; + + //@@ .. cpp:enumerator:: Kind::BATCH_MAX_ELEMENT_COUNT_AS_SHAPE = 3 + //@@ + //@@ Among the requests in the batch, the max element count of the + //@@ 'source_input' will be added as input with shape + //@@ [max_element_count] for the first request in the batch. + //@@ For other requests, such input will be with shape [0]. + //@@ The data of the tensor will be uninitialized. + //@@ + BATCH_MAX_ELEMENT_COUNT_AS_SHAPE = 3; + + //@@ .. cpp:enumerator:: Kind::BATCH_ITEM_SHAPE = 4 + //@@ + //@@ Among the requests in the batch, the shape of the + //@@ 'source_input' will be added as input with shape + //@@ [batch_size, len(input_dim)]. For example, if one + //@@ batch-2 input with shape [3, 1] and batch-1 input + //@@ with shape [2, 2] are batched, the batch input will + //@@ have shape [3, 2] and value [ [3, 1], [3, 1], [2, 2]]. + //@@ + BATCH_ITEM_SHAPE = 4; + + //@@ .. cpp:enumerator:: Kind::BATCH_ITEM_SHAPE_FLATTEN = 5 + //@@ + //@@ Among the requests in the batch, the shape of the + //@@ 'source_input' will be added as input with single dimensional + //@@ shape [batch_size * len(input_dim)]. For example, if one + //@@ batch-2 input with shape [3, 1] and batch-1 input + //@@ with shape [2, 2] are batched, the batch input will + //@@ have shape [6] and value [3, 1, 3, 1, 2, 2]. + //@@ + BATCH_ITEM_SHAPE_FLATTEN = 5; + } + + //@@ .. cpp:var:: Kind kind + //@@ + //@@ The kind of this batch input. + //@@ + Kind kind = 1; + + //@@ .. cpp:var:: string target_name (repeated) + //@@ + //@@ The name of the model inputs that the backend will create + //@@ for this batch input. + //@@ + repeated string target_name = 2; + + //@@ .. cpp:var:: DataType data_type + //@@ + //@@ The input's datatype. The data type can be TYPE_INT32 or + //@@ TYPE_FP32. + //@@ + DataType data_type = 3; + + //@@ .. cpp:var:: string source_input (repeated) + //@@ + //@@ The backend derives the value for each batch input from one or + //@@ more other inputs. 'source_input' gives the names of those + //@@ inputs. + //@@ + repeated string source_input = 4; +} + +//@@.. cpp:var:: message BatchOutput +//@@ +//@@ A batch output is an output produced by the model that must be handled +//@@ differently by the backend based on all the requests in a batch. +//@@ +message BatchOutput +{ + //@@ + //@@ .. cpp:enum:: Kind + //@@ + //@@ The kind of the batch output. + //@@ + enum Kind { + //@@ .. cpp:enumerator:: Kind::BATCH_SCATTER_WITH_INPUT_SHAPE = 0 + //@@ + //@@ The output should be scattered according to the shape of + //@@ 'source_input'. The dynamic dimension of the output will + //@@ be set to the value of the same dimension in the input. + //@@ + BATCH_SCATTER_WITH_INPUT_SHAPE = 0; + } + + //@@ .. cpp:var:: string target_name (repeated) + //@@ + //@@ The name of the outputs to be produced by this batch output + //@@ specification. 
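[Editor's note] As a rough illustration of the batch-input kinds described above, a model configuration might declare a BatchInput entry like the sketch below (Protocol Buffers text format). The tensor names are hypothetical, and the enclosing repeated field is assumed to be named batch_input as in the full ModelConfig message, which is not shown in this excerpt:

    # Hypothetical config.pbtxt fragment: ask the backend to add a
    # one-element INT32 tensor holding the element count of INPUT0
    # for each request in the batch.
    batch_input [
      {
        kind: BATCH_ELEMENT_COUNT
        target_name: "INPUT0_ELEMENT_COUNT"   # extra input the backend creates (hypothetical name)
        data_type: TYPE_INT32
        source_input: "INPUT0"                # request input the value is derived from (hypothetical name)
      }
    ]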
+ //@@ + repeated string target_name = 1; + + //@@ .. cpp:var:: Kind kind + //@@ + //@@ The kind of this batch output. + //@@ + Kind kind = 2; + + //@@ .. cpp:var:: string source_input (repeated) + //@@ + //@@ The backend derives each batch output from one or more inputs. + //@@ 'source_input' gives the names of those inputs. + //@@ + repeated string source_input = 3; +} + +//@@ +//@@.. cpp:var:: message ModelVersionPolicy +//@@ +//@@ Policy indicating which versions of a model should be made +//@@ available by the inference server. +//@@ +message ModelVersionPolicy +{ + //@@ .. cpp:var:: message Latest + //@@ + //@@ Serve only the latest version(s) of a model. This is + //@@ the default policy. + //@@ + message Latest + { + //@@ .. cpp:var:: uint32 num_versions + //@@ + //@@ Serve only the 'num_versions' highest-numbered versions. T + //@@ The default value of 'num_versions' is 1, indicating that by + //@@ default only the single highest-number version of a + //@@ model will be served. + //@@ + uint32 num_versions = 1; + } + + //@@ .. cpp:var:: message All + //@@ + //@@ Serve all versions of the model. + //@@ + message All {} + + //@@ .. cpp:var:: message Specific + //@@ + //@@ Serve only specific versions of the model. + //@@ + message Specific + { + //@@ .. cpp:var:: int64 versions (repeated) + //@@ + //@@ The specific versions of the model that will be served. + //@@ + repeated int64 versions = 1; + } + + //@@ .. cpp:var:: oneof policy_choice + //@@ + //@@ Each model must implement only a single version policy. The + //@@ default policy is 'Latest'. + //@@ + oneof policy_choice + { + //@@ .. cpp:var:: Latest latest + //@@ + //@@ Serve only latest version(s) of the model. + //@@ + Latest latest = 1; + + //@@ .. cpp:var:: All all + //@@ + //@@ Serve all versions of the model. + //@@ + All all = 2; + + //@@ .. cpp:var:: Specific specific + //@@ + //@@ Serve only specific version(s) of the model. + //@@ + Specific specific = 3; + } +} + +//@@ +//@@.. cpp:var:: message ModelOptimizationPolicy +//@@ +//@@ Optimization settings for a model. These settings control if/how a +//@@ model is optimized and prioritized by the backend framework when +//@@ it is loaded. +//@@ +message ModelOptimizationPolicy +{ + //@@ + //@@ .. cpp:var:: message Graph + //@@ + //@@ Enable generic graph optimization of the model. If not specified + //@@ the framework's default level of optimization is used. Supports + //@@ TensorFlow graphdef and savedmodel and Onnx models. For TensorFlow + //@@ causes XLA to be enabled/disabled for the model. For Onnx defaults + //@@ to enabling all optimizations, -1 enables only basic optimizations, + //@@ +1 enables only basic and extended optimizations. + //@@ + message Graph + { + //@@ .. cpp:var:: int32 level + //@@ + //@@ The optimization level. Defaults to 0 (zero) if not specified. + //@@ + //@@ - -1: Disabled + //@@ - 0: Framework default + //@@ - 1+: Enable optimization level (greater values indicate + //@@ higher optimization levels) + //@@ + int32 level = 1; + } + + //@@ + //@@ .. cpp:enum:: ModelPriority + //@@ + //@@ Model priorities. A model will be given scheduling and execution + //@@ preference over models at lower priorities. Current model + //@@ priorities only work for TensorRT models. + //@@ + enum ModelPriority { + //@@ .. cpp:enumerator:: ModelPriority::PRIORITY_DEFAULT = 0 + //@@ + //@@ The default model priority. + //@@ + PRIORITY_DEFAULT = 0; + + //@@ .. cpp:enumerator:: ModelPriority::PRIORITY_MAX = 1 + //@@ + //@@ The maximum model priority. 
+ //@@ + PRIORITY_MAX = 1; + + //@@ .. cpp:enumerator:: ModelPriority::PRIORITY_MIN = 2 + //@@ + //@@ The minimum model priority. + //@@ + PRIORITY_MIN = 2; + } + + //@@ + //@@ .. cpp:var:: message Cuda + //@@ + //@@ CUDA-specific optimization settings. + //@@ + message Cuda + { + //@@ .. cpp:var:: message GraphSpec + //@@ + //@@ Specification of the CUDA graph to be captured. + //@@ + message GraphSpec + { + //@@ .. cpp:var:: message Dims + //@@ + //@@ Specification of tensor dimension. + //@@ + message Shape + { + //@@ .. cpp:var:: int64 dim (repeated) + //@@ + //@@ The dimension. + //@@ + repeated int64 dim = 1; + } + + message LowerBound + { + //@@ .. cpp:var:: int32 batch_size + //@@ + //@@ The batch size of the CUDA graph. If 'max_batch_size' is 0, + //@@ 'batch_size' must be set to 0. Otherwise, 'batch_size' must + //@@ be set to value between 1 and 'max_batch_size'. + //@@ + int32 batch_size = 1; + + //@@ .. cpp:var:: map input + //@@ + //@@ The specification of the inputs. 'Shape' is the shape of + //@@ the input without batching dimension. + //@@ + map input = 2; + } + + //@@ .. cpp:var:: int32 batch_size + //@@ + //@@ The batch size of the CUDA graph. If 'max_batch_size' is 0, + //@@ 'batch_size' must be set to 0. Otherwise, 'batch_size' must + //@@ be set to value between 1 and 'max_batch_size'. + //@@ + int32 batch_size = 1; + + //@@ .. cpp:var:: map input + //@@ + //@@ The specification of the inputs. 'Shape' is the shape of the + //@@ input without batching dimension. + //@@ + map input = 2; + + //@@ .. cpp:var:: LowerBound graph_lower_bound + //@@ + //@@ Specify the lower bound of the CUDA graph. Optional. + //@@ If specified, the graph can be used for input shapes and + //@@ batch sizes that are in closed interval between the lower + //@@ bound specification and graph specification. For dynamic + //@@ shape model, this allows CUDA graphs to be launched + //@@ frequently without capturing all possible shape combinations. + //@@ However, using graph for shape combinations different from + //@@ the one used for capturing introduces uninitialized data for + //@@ execution and it may distort the inference result if + //@@ the model is sensitive to uninitialized data. + //@@ + LowerBound graph_lower_bound = 3; + } + + //@@ .. cpp:var:: bool graphs + //@@ + //@@ Use CUDA graphs API to capture model operations and execute + //@@ them more efficiently. Default value is false. + //@@ Currently only recognized by TensorRT backend. + //@@ + bool graphs = 1; + + //@@ .. cpp:var:: bool busy_wait_events + //@@ + //@@ Use busy-waiting to synchronize CUDA events to achieve minimum + //@@ latency from event complete to host thread to be notified, with + //@@ the cost of high CPU load. Default value is false. + //@@ Currently only recognized by TensorRT backend. + //@@ + bool busy_wait_events = 2; + + //@@ .. cpp:var:: GraphSpec graph_spec (repeated) + //@@ + //@@ Specification of the CUDA graph to be captured. If not specified + //@@ and 'graphs' is true, the default CUDA graphs will be captured + //@@ based on model settings. + //@@ Currently only recognized by TensorRT backend. + //@@ + repeated GraphSpec graph_spec = 3; + + //@@ .. cpp:var:: bool output_copy_stream + //@@ + //@@ Uses a CUDA stream separate from the inference stream to copy the + //@@ output to host. However, be aware that setting this option to + //@@ true will lead to an increase in the memory consumption of the + //@@ model as Triton will allocate twice as much GPU memory for its + //@@ I/O tensor buffers. 
Default value is false. + //@@ Currently only recognized by TensorRT backend. + //@@ + bool output_copy_stream = 4; + } + + //@@ + //@@ .. cpp:var:: message ExecutionAccelerators + //@@ + //@@ Specify the preferred execution accelerators to be used to execute + //@@ the model. Currently only recognized by ONNX Runtime backend and + //@@ TensorFlow backend. + //@@ + //@@ For ONNX Runtime backend, it will deploy the model with the execution + //@@ accelerators by priority, the priority is determined based on the + //@@ order that they are set, i.e. the provider at the front has highest + //@@ priority. Overall, the priority will be in the following order: + //@@ (if instance is on GPU) + //@@ CUDA Execution Provider (if instance is on GPU) + //@@ + //@@ Default CPU Execution Provider + //@@ + message ExecutionAccelerators + { + //@@ + //@@ .. cpp:var:: message Accelerator + //@@ + //@@ Specify the accelerator to be used to execute the model. + //@@ Accelerator with the same name may accept different parameters + //@@ depending on the backends. + //@@ + message Accelerator + { + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the execution accelerator. + //@@ + string name = 1; + + //@@ .. cpp:var:: map parameters + //@@ + //@@ Additional paremeters used to configure the accelerator. + //@@ + map parameters = 2; + } + + //@@ .. cpp:var:: Accelerator gpu_execution_accelerator (repeated) + //@@ + //@@ The preferred execution provider to be used if the model instance + //@@ is deployed on GPU. + //@@ + //@@ For ONNX Runtime backend, possible value is "tensorrt" as name, + //@@ and no parameters are required. + //@@ + //@@ For TensorFlow backend, possible values are "tensorrt", + //@@ "auto_mixed_precision", "gpu_io". + //@@ + //@@ For "tensorrt", the following parameters can be specified: + //@@ "precision_mode": The precision used for optimization. + //@@ Allowed values are "FP32" and "FP16". Default value is "FP32". + //@@ + //@@ "max_cached_engines": The maximum number of cached TensorRT + //@@ engines in dynamic TensorRT ops. Default value is 100. + //@@ + //@@ "minimum_segment_size": The smallest model subgraph that will + //@@ be considered for optimization by TensorRT. Default value is 3. + //@@ + //@@ "max_workspace_size_bytes": The maximum GPU memory the model + //@@ can use temporarily during execution. Default value is 1GB. + //@@ + //@@ For "auto_mixed_precision", no parameters are required. If set, + //@@ the model will try to use FP16 for better performance. + //@@ This optimization can not be set with "tensorrt". + //@@ + //@@ For "gpu_io", no parameters are required. If set, the model will + //@@ be executed using TensorFlow Callable API to set input and output + //@@ tensors in GPU memory if possible, which can reduce data transfer + //@@ overhead if the model is used in ensemble. However, the Callable + //@@ object will be created on model creation and it will request all + //@@ outputs for every model execution, which may impact the + //@@ performance if a request does not require all outputs. This + //@@ optimization will only take affect if the model instance is + //@@ created with KIND_GPU. + //@@ + repeated Accelerator gpu_execution_accelerator = 1; + + //@@ .. cpp:var:: Accelerator cpu_execution_accelerator (repeated) + //@@ + //@@ The preferred execution provider to be used if the model instance + //@@ is deployed on CPU. + //@@ + //@@ For ONNX Runtime backend, possible value is "openvino" as name, + //@@ and no parameters are required. 
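+    //@@  As an illustrative config.pbtxt fragment (the parameter value
+    //@@  shown is an example only), a model could request TensorRT on
+    //@@  GPU instances and OpenVINO on CPU instances with:
+    //@@
+    //@@    optimization {
+    //@@      execution_accelerators {
+    //@@        gpu_execution_accelerator: [
+    //@@          { name: "tensorrt"
+    //@@            parameters { key: "precision_mode" value: "FP16" } }
+    //@@        ]
+    //@@        cpu_execution_accelerator: [ { name: "openvino" } ]
+    //@@      }
+    //@@    }
+    //@@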
+ //@@ + repeated Accelerator cpu_execution_accelerator = 2; + } + + //@@ + //@@ .. cpp:var:: message PinnedMemoryBuffer + //@@ + //@@ Specify whether to use a pinned memory buffer when transferring data + //@@ between non-pinned system memory and GPU memory. Using a pinned + //@@ memory buffer for system from/to GPU transfers will typically provide + //@@ increased performance. For example, in the common use case where the + //@@ request provides inputs and delivers outputs via non-pinned system + //@@ memory, if the model instance accepts GPU IOs, the inputs will be + //@@ processed by two copies: from non-pinned system memory to pinned + //@@ memory, and from pinned memory to GPU memory. Similarly, pinned + //@@ memory will be used for delivering the outputs. + //@@ + message PinnedMemoryBuffer + { + //@@ .. cpp:var:: bool enable + //@@ + //@@ Use pinned memory buffer. Default is true. + //@@ + bool enable = 1; + } + + //@@ .. cpp:var:: Graph graph + //@@ + //@@ The graph optimization setting for the model. Optional. + //@@ + Graph graph = 1; + + //@@ .. cpp:var:: ModelPriority priority + //@@ + //@@ The priority setting for the model. Optional. + //@@ + ModelPriority priority = 2; + + //@@ .. cpp:var:: Cuda cuda + //@@ + //@@ CUDA-specific optimization settings. Optional. + //@@ + Cuda cuda = 3; + + //@@ .. cpp:var:: ExecutionAccelerators execution_accelerators + //@@ + //@@ The accelerators used for the model. Optional. + //@@ + ExecutionAccelerators execution_accelerators = 4; + + //@@ .. cpp:var:: PinnedMemoryBuffer input_pinned_memory + //@@ + //@@ Use pinned memory buffer when the data transfer for inputs + //@@ is between GPU memory and non-pinned system memory. + //@@ Default is true. + //@@ + PinnedMemoryBuffer input_pinned_memory = 5; + + //@@ .. cpp:var:: PinnedMemoryBuffer output_pinned_memory + //@@ + //@@ Use pinned memory buffer when the data transfer for outputs + //@@ is between GPU memory and non-pinned system memory. + //@@ Default is true. + //@@ + PinnedMemoryBuffer output_pinned_memory = 6; + + //@@ .. cpp:var:: uint32 gather_kernel_buffer_threshold + //@@ + //@@ The backend may use a gather kernel to gather input data if the + //@@ device has direct access to the source buffer and the destination + //@@ buffer. In such case, the gather kernel will be used only if the + //@@ number of buffers to be gathered is greater or equal to + //@@ the specifed value. If 0, the gather kernel will be disabled. + //@@ Default value is 0. + //@@ Currently only recognized by TensorRT backend. + //@@ + uint32 gather_kernel_buffer_threshold = 7; + + //@@ .. cpp:var:: bool eager_batching + //@@ + //@@ Start preparing the next batch before the model instance is ready + //@@ for the next inference. This option can be used to overlap the + //@@ batch preparation with model execution, with the trade-off that + //@@ the next batch might be smaller than what it could have been. + //@@ Default value is false. + //@@ Currently only recognized by TensorRT backend. + //@@ + bool eager_batching = 8; +} + +//@@ +//@@.. cpp:var:: message ModelQueuePolicy +//@@ +//@@ Queue policy for inference requests. +//@@ +message ModelQueuePolicy +{ + //@@ + //@@ .. cpp:enum:: TimeoutAction + //@@ + //@@ The action applied to timed-out requests. + //@@ + enum TimeoutAction { + //@@ .. cpp:enumerator:: Action::REJECT = 0 + //@@ + //@@ Reject the request and return error message accordingly. + //@@ + REJECT = 0; + + //@@ .. 
cpp:enumerator:: Action::DELAY = 1 + //@@ + //@@ Delay the request until all other requests at the same + //@@ (or higher) priority levels that have not reached their timeouts + //@@ are processed. A delayed request will eventually be processed, + //@@ but may be delayed indefinitely due to newly arriving requests. + //@@ + DELAY = 1; + } + + //@@ + //@@ .. cpp:var:: TimeoutAction timeout_action + //@@ + //@@ The action applied to timed-out request. + //@@ The default action is REJECT. + //@@ + TimeoutAction timeout_action = 1; + + //@@ + //@@ .. cpp:var:: uint64 default_timeout_microseconds + //@@ + //@@ The default timeout for every request, in microseconds. + //@@ The default value is 0 which indicates that no timeout is set. + //@@ + uint64 default_timeout_microseconds = 2; + + //@@ + //@@ .. cpp:var:: bool allow_timeout_override + //@@ + //@@ Whether individual request can override the default timeout value. + //@@ When true, individual requests can set a timeout that is less than + //@@ the default timeout value but may not increase the timeout. + //@@ The default value is false. + //@@ + bool allow_timeout_override = 3; + + //@@ + //@@ .. cpp:var:: uint32 max_queue_size + //@@ + //@@ The maximum queue size for holding requests. A request will be + //@@ rejected immediately if it can't be enqueued because the queue is + //@@ full. The default value is 0 which indicates that no maximum + //@@ queue size is enforced. + //@@ + uint32 max_queue_size = 4; +} + +//@@ +//@@.. cpp:var:: message ModelDynamicBatching +//@@ +//@@ Dynamic batching configuration. These settings control how dynamic +//@@ batching operates for the model. +//@@ +message ModelDynamicBatching +{ + //@@ .. cpp:var:: int32 preferred_batch_size (repeated) + //@@ + //@@ Preferred batch sizes for dynamic batching. If a batch of one of + //@@ these sizes can be formed it will be executed immediately. If + //@@ not specified a preferred batch size will be chosen automatically + //@@ based on model and GPU characteristics. + //@@ + repeated int32 preferred_batch_size = 1; + + //@@ .. cpp:var:: uint64 max_queue_delay_microseconds + //@@ + //@@ The maximum time, in microseconds, a request will be delayed in + //@@ the scheduling queue to wait for additional requests for + //@@ batching. Default is 0. + //@@ + uint64 max_queue_delay_microseconds = 2; + + //@@ .. cpp:var:: bool preserve_ordering + //@@ + //@@ Should the dynamic batcher preserve the ordering of responses to + //@@ match the order of requests received by the scheduler. Default is + //@@ false. If true, the responses will be returned in the same order as + //@@ the order of requests sent to the scheduler. If false, the responses + //@@ may be returned in arbitrary order. This option is specifically + //@@ needed when a sequence of related inference requests (i.e. inference + //@@ requests with the same correlation ID) are sent to the dynamic + //@@ batcher to ensure that the sequence responses are in the correct + //@@ order. + //@@ + bool preserve_ordering = 3; + + //@@ .. cpp:var:: uint32 priority_levels + //@@ + //@@ The number of priority levels to be enabled for the model, + //@@ the priority level starts from 1 and 1 is the highest priority. + //@@ Requests are handled in priority order with all priority 1 requests + //@@ processed before priority 2, all priority 2 requests processed before + //@@ priority 3, etc. Requests with the same priority level will be + //@@ handled in the order that they are received. + //@@ + uint32 priority_levels = 4; + + //@@ .. 
cpp:var:: uint32 default_priority_level + //@@ + //@@ The priority level used for requests that don't specify their + //@@ priority. The value must be in the range [ 1, 'priority_levels' ]. + //@@ + uint32 default_priority_level = 5; + + //@@ .. cpp:var:: ModelQueuePolicy default_queue_policy + //@@ + //@@ The default queue policy used for requests that don't require + //@@ priority handling and requests that specify priority levels where + //@@ there is no specific policy given. If not specified, a policy with + //@@ default field values will be used. + //@@ + ModelQueuePolicy default_queue_policy = 6; + + //@@ .. cpp:var:: map priority_queue_policy + //@@ + //@@ Specify the queue policy for the priority level. The default queue + //@@ policy will be used if a priority level doesn't specify a queue + //@@ policy. + //@@ + map priority_queue_policy = 7; +} + +//@@ +//@@.. cpp:var:: message ModelSequenceBatching +//@@ +//@@ Sequence batching configuration. These settings control how sequence +//@@ batching operates for the model. +//@@ +message ModelSequenceBatching +{ + //@@ .. cpp:var:: message Control + //@@ + //@@ A control is a signal that the sequence batcher uses to + //@@ communicate with a backend. + //@@ + message Control + { + //@@ + //@@ .. cpp:enum:: Kind + //@@ + //@@ The kind of the control. + //@@ + enum Kind { + //@@ .. cpp:enumerator:: Kind::CONTROL_SEQUENCE_START = 0 + //@@ + //@@ A new sequence is/is-not starting. If true a sequence is + //@@ starting, if false a sequence is continuing. Must + //@@ specify either int32_false_true, fp32_false_true or + //@@ bool_false_true for this control. This control is optional. + //@@ + CONTROL_SEQUENCE_START = 0; + + //@@ .. cpp:enumerator:: Kind::CONTROL_SEQUENCE_READY = 1 + //@@ + //@@ A sequence is/is-not ready for inference. If true the + //@@ input tensor data is valid and should be used. If false + //@@ the input tensor data is invalid and inferencing should + //@@ be "skipped". Must specify either int32_false_true, + //@@ fp32_false_true or bool_false_true for this control. This + //@@ control is optional. + //@@ + CONTROL_SEQUENCE_READY = 1; + + //@@ .. cpp:enumerator:: Kind::CONTROL_SEQUENCE_END = 2 + //@@ + //@@ A sequence is/is-not ending. If true a sequence is + //@@ ending, if false a sequence is continuing. Must specify + //@@ either int32_false_true, fp32_false_true or bool_false_true + //@@ for this control. This control is optional. + //@@ + CONTROL_SEQUENCE_END = 2; + + //@@ .. cpp:enumerator:: Kind::CONTROL_SEQUENCE_CORRID = 3 + //@@ + //@@ The correlation ID of the sequence. The correlation ID + //@@ is an uint64_t value that is communicated in whole or + //@@ in part by the tensor. The tensor's datatype must be + //@@ specified by data_type and must be TYPE_UINT64, TYPE_INT64, + //@@ TYPE_UINT32 or TYPE_INT32. If a 32-bit datatype is specified + //@@ the correlation ID will be truncated to the low-order 32 + //@@ bits. This control is optional. + //@@ + CONTROL_SEQUENCE_CORRID = 3; + } + + //@@ .. cpp:var:: Kind kind + //@@ + //@@ The kind of this control. + //@@ + Kind kind = 1; + + //@@ .. cpp:var:: int32 int32_false_true (repeated) + //@@ + //@@ The control's true and false setting is indicated by setting + //@@ a value in an int32 tensor. The tensor must be a + //@@ 1-dimensional tensor with size equal to the batch size of + //@@ the request. 'int32_false_true' must have two entries: the + //@@ first the false value and the second the true value. + //@@ + repeated int32 int32_false_true = 2; + + //@@ .. 
cpp:var:: float fp32_false_true (repeated) + //@@ + //@@ The control's true and false setting is indicated by setting + //@@ a value in a fp32 tensor. The tensor must be a + //@@ 1-dimensional tensor with size equal to the batch size of + //@@ the request. 'fp32_false_true' must have two entries: the + //@@ first the false value and the second the true value. + //@@ + repeated float fp32_false_true = 3; + + //@@ .. cpp:var:: bool bool_false_true (repeated) + //@@ + //@@ The control's true and false setting is indicated by setting + //@@ a value in a bool tensor. The tensor must be a + //@@ 1-dimensional tensor with size equal to the batch size of + //@@ the request. 'bool_false_true' must have two entries: the + //@@ first the false value and the second the true value. + //@@ + repeated bool bool_false_true = 5; + + //@@ .. cpp:var:: DataType data_type + //@@ + //@@ The control's datatype. + //@@ + DataType data_type = 4; + } + + //@@ .. cpp:var:: message ControlInput + //@@ + //@@ The sequence control values to communicate by a model input. + //@@ + message ControlInput + { + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the model input. + //@@ + string name = 1; + + //@@ .. cpp:var:: Control control (repeated) + //@@ + //@@ The control value(s) that should be communicated to the + //@@ model using this model input. + //@@ + repeated Control control = 2; + } + + //@@ + //@@ .. cpp:var:: message InitialState + //@@ + //@@ Settings used to initialize data for implicit state. + //@@ + message InitialState + { + //@@ .. cpp:var:: DataType data_type + //@@ + //@@ The data-type of the state. + //@@ + DataType data_type = 1; + + //@@ .. cpp:var:: int64 dims (repeated) + //@@ + //@@ The shape of the state tensor, not including the batch dimension. + //@@ + repeated int64 dims = 2; + + //@@ .. cpp:var:: oneof state_data + //@@ + //@@ Specify how the initial state data is generated. + //@@ + oneof state_data + { + //@@ + //@@ .. cpp:var:: bool zero_data + //@@ + //@@ The identifier for using zeros as initial state data. + //@@ Note that the value of 'zero_data' will not be checked, + //@@ instead, zero data will be used as long as the field is set. + //@@ + bool zero_data = 3; + + //@@ .. cpp:var:: string data_file + //@@ + //@@ The file whose content will be used as the initial data for + //@@ the state in row-major order. The file must be provided in + //@@ sub-directory 'initial_state' under the model directory. + //@@ + string data_file = 4; + } + + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the state initialization. + //@@ + string name = 5; + } + + //@@ .. cpp:var:: message State + //@@ + //@@ An input / output pair of tensors that carry state for the sequence. + //@@ + message State + { + //@@ .. cpp:var:: string input_name + //@@ + //@@ The name of the model state input. + //@@ + string input_name = 1; + + //@@ .. cpp:var:: string output_name + //@@ + //@@ The name of the model state output. + //@@ + string output_name = 2; + + //@@ .. cpp:var:: DataType data_type + //@@ + //@@ The data-type of the state. + //@@ + DataType data_type = 3; + + //@@ .. cpp:var:: int64 dim (repeated) + //@@ + //@@ The dimension. + //@@ + repeated int64 dims = 4; + + //@@ .. cpp:var:: InitialState initial_state (repeated) + //@@ + //@@ The optional field to specify the initial state for the model. + //@@ + repeated InitialState initial_state = 5; + } + + //@@ .. cpp:var:: message StrategyDirect + //@@ + //@@ The sequence batcher uses a specific, unique batch + //@@ slot for each sequence. 
All inference requests in a + //@@ sequence are directed to the same batch slot in the same + //@@ model instance over the lifetime of the sequence. This + //@@ is the default strategy. + //@@ + message StrategyDirect + { + //@@ .. cpp:var:: uint64 max_queue_delay_microseconds + //@@ + //@@ The maximum time, in microseconds, a candidate request + //@@ will be delayed in the sequence batch scheduling queue to + //@@ wait for additional requests for batching. Default is 0. + //@@ + uint64 max_queue_delay_microseconds = 1; + + //@@ .. cpp:var:: float minimum_slot_utilization + //@@ + //@@ The minimum slot utilization that must be satisfied to + //@@ execute the batch before 'max_queue_delay_microseconds' expires. + //@@ For example, a value of 0.5 indicates that the batch should be + //@@ executed as soon as 50% or more of the slots are ready even if + //@@ the 'max_queue_delay_microseconds' timeout has not expired. + //@@ The default is 0.0, indicating that a batch will be executed + //@@ before 'max_queue_delay_microseconds' timeout expires if at least + //@@ one batch slot is ready. 'max_queue_delay_microseconds' will be + //@@ ignored unless minimum_slot_utilization is set to a non-zero + //@@ value. + //@@ + float minimum_slot_utilization = 2; + } + + //@@ .. cpp:var:: message StrategyOldest + //@@ + //@@ The sequence batcher maintains up to 'max_candidate_sequences' + //@@ candidate sequences. 'max_candidate_sequences' can be greater + //@@ than the model's 'max_batch_size'. For inferencing the batcher + //@@ chooses from the candidate sequences up to 'max_batch_size' + //@@ inference requests. Requests are chosen in an oldest-first + //@@ manner across all candidate sequences. A given sequence is + //@@ not guaranteed to be assigned to the same batch slot for + //@@ all inference requests of that sequence. + //@@ + message StrategyOldest + { + //@@ .. cpp:var:: int32 max_candidate_sequences + //@@ + //@@ Maximum number of candidate sequences that the batcher + //@@ maintains. Excess seqences are kept in an ordered backlog + //@@ and become candidates when existing candidate sequences + //@@ complete. + //@@ + int32 max_candidate_sequences = 1; + + //@@ .. cpp:var:: int32 preferred_batch_size (repeated) + //@@ + //@@ Preferred batch sizes for dynamic batching of candidate + //@@ sequences. If a batch of one of these sizes can be formed + //@@ it will be executed immediately. If not specified a + //@@ preferred batch size will be chosen automatically + //@@ based on model and GPU characteristics. + //@@ + repeated int32 preferred_batch_size = 2; + + //@@ .. cpp:var:: uint64 max_queue_delay_microseconds + //@@ + //@@ The maximum time, in microseconds, a candidate request + //@@ will be delayed in the dynamic batch scheduling queue to + //@@ wait for additional requests for batching. Default is 0. + //@@ + uint64 max_queue_delay_microseconds = 3; + } + + //@@ .. cpp:var:: oneof strategy_choice + //@@ + //@@ The strategy used by the sequence batcher. Default strategy + //@@ is 'direct'. + //@@ + oneof strategy_choice + { + //@@ .. cpp:var:: StrategyDirect direct + //@@ + //@@ StrategyDirect scheduling strategy. + //@@ + StrategyDirect direct = 3; + + //@@ .. cpp:var:: StrategyOldest oldest + //@@ + //@@ StrategyOldest scheduling strategy. + //@@ + StrategyOldest oldest = 4; + } + + //@@ .. cpp:var:: uint64 max_sequence_idle_microseconds + //@@ + //@@ The maximum time, in microseconds, that a sequence is allowed to + //@@ be idle before it is aborted. 
The inference server considers a + //@@ sequence idle when it does not have any inference request queued + //@@ for the sequence. If this limit is exceeded, the inference server + //@@ will free the sequence slot allocated by the sequence and make it + //@@ available for another sequence. If not specified (or specified as + //@@ zero) a default value of 1000000 (1 second) is used. + //@@ + uint64 max_sequence_idle_microseconds = 1; + + //@@ .. cpp:var:: ControlInput control_input (repeated) + //@@ + //@@ The model input(s) that the server should use to communicate + //@@ sequence start, stop, ready and similar control values to the + //@@ model. + //@@ + repeated ControlInput control_input = 2; + + //@@ .. cpp:var:: State state (repeated) + //@@ + //@@ The optional state that can be stored in Triton for performing + //@@ inference requests on a sequence. Each sequence holds an implicit + //@@ state local to itself. The output state tensor provided by the + //@@ model in 'output_name' field of the current inference request will + //@@ be transferred as an input tensor named 'input_name' in the next + //@@ request of the same sequence. The input state of the first request + //@@ in the sequence contains garbage data. + //@@ + repeated State state = 5; +} + +//@@ +//@@.. cpp:var:: message ModelEnsembling +//@@ +//@@ Model ensembling configuration. These settings specify the models that +//@@ compose the ensemble and how data flows between the models. +//@@ +message ModelEnsembling +{ + //@@ .. cpp:var:: message Step + //@@ + //@@ Each step specifies a model included in the ensemble, + //@@ maps ensemble tensor names to the model input tensors, + //@@ and maps model output tensors to ensemble tensor names + //@@ + message Step + { + //@@ .. cpp:var:: string model_name + //@@ + //@@ The name of the model to execute for this step of the ensemble. + //@@ + string model_name = 1; + + //@@ .. cpp:var:: int64 model_version + //@@ + //@@ The version of the model to use for inference. If -1 + //@@ the latest/most-recent version of the model is used. + //@@ + int64 model_version = 2; + + //@@ .. cpp:var:: map input_map + //@@ + //@@ Map from name of an input tensor on this step's model to ensemble + //@@ tensor name. The ensemble tensor must have the same data type and + //@@ shape as the model input. Each model input must be assigned to + //@@ one ensemble tensor, but the same ensemble tensor can be assigned + //@@ to multiple model inputs. + //@@ + map input_map = 3; + + //@@ .. cpp:var:: map output_map + //@@ + //@@ Map from name of an output tensor on this step's model to ensemble + //@@ tensor name. The data type and shape of the ensemble tensor will + //@@ be inferred from the model output. It is optional to assign all + //@@ model outputs to ensemble tensors. One ensemble tensor name + //@@ can appear in an output map only once. + //@@ + map output_map = 4; + } + + //@@ .. cpp:var:: Step step (repeated) + //@@ + //@@ The models and the input / output mappings used within the ensemble. + //@@ + repeated Step step = 1; +} + +//@@ +//@@.. cpp:var:: message ModelParameter +//@@ +//@@ A model parameter. +//@@ +message ModelParameter +{ + //@@ .. cpp:var:: string string_value + //@@ + //@@ The string value of the parameter. + //@@ + string string_value = 1; +} + +//@@ +//@@.. cpp:var:: message ModelWarmup +//@@ +//@@ Settings used to construct the request sample for model warmup. +//@@ +message ModelWarmup +{ + //@@ + //@@ .. cpp:var:: message Input + //@@ + //@@ Meta data associated with an input. 
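+  //@@  For illustration (the sample name, tensor name and shape are
+  //@@  placeholders), a warmup sample that feeds a zero-filled input
+  //@@  could be declared in config.pbtxt as:
+  //@@
+  //@@    model_warmup [
+  //@@      {
+  //@@        name: "zero_sample"
+  //@@        batch_size: 1
+  //@@        inputs {
+  //@@          key: "INPUT0"
+  //@@          value { data_type: TYPE_FP32 dims: [ 16 ] zero_data: true }
+  //@@        }
+  //@@      }
+  //@@    ]
+  //@@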
+ //@@ + message Input + { + //@@ .. cpp:var:: DataType data_type + //@@ + //@@ The data-type of the input. + //@@ + DataType data_type = 1; + + //@@ .. cpp:var:: int64 dims (repeated) + //@@ + //@@ The shape of the input tensor, not including the batch dimension. + //@@ + repeated int64 dims = 2; + + //@@ .. cpp:var:: oneof input_data_type + //@@ + //@@ Specify how the input data is generated. If the input has STRING + //@@ data type and 'random_data' is set, the data generation will fall + //@@ back to 'zero_data'. + //@@ + oneof input_data_type + { + //@@ + //@@ .. cpp:var:: bool zero_data + //@@ + //@@ The identifier for using zeros as input data. Note that the + //@@ value of 'zero_data' will not be checked, instead, zero data + //@@ will be used as long as the field is set. + //@@ + bool zero_data = 3; + + //@@ + //@@ .. cpp:var:: bool random_data + //@@ + //@@ The identifier for using random data as input data. Note that + //@@ the value of 'random_data' will not be checked, instead, + //@@ random data will be used as long as the field is set. + //@@ + bool random_data = 4; + + //@@ .. cpp:var:: string input_data_file + //@@ + //@@ The file whose content will be used as raw input data in + //@@ row-major order. The file must be provided in a sub-directory + //@@ 'warmup' under the model directory. The file contents should be + //@@ in binary format. For TYPE_STRING data-type, an element is + //@@ represented by a 4-byte unsigned integer giving the length + //@@ followed by the actual bytes. + //@@ + string input_data_file = 5; + } + } + + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the request sample. + //@@ + string name = 1; + + //@@ .. cpp:var:: uint32 batch_size + //@@ + //@@ The batch size of the inference request. This must be >= 1. For + //@@ models that don't support batching, batch_size must be 1. If + //@@ batch_size > 1, the 'inputs' specified below will be duplicated to + //@@ match the batch size requested. + //@@ + uint32 batch_size = 2; + + //@@ .. cpp:var:: map inputs + //@@ + //@@ The warmup meta data associated with every model input, including + //@@ control tensors. + //@@ + map inputs = 3; + + //@@ .. cpp:var:: uint32 count + //@@ + //@@ The number of iterations that this warmup sample will be executed. + //@@ For example, if this field is set to 2, 2 model executions using this + //@@ sample will be scheduled for warmup. Default value is 0 which + //@@ indicates that this sample will be used only once. + //@@ Note that for sequence model, 'count' may not work well + //@@ because the model often expect a valid sequence of requests which + //@@ should be represented by a series of warmup samples. 'count > 1' + //@@ essentially "resends" one of the sample, which may invalidate the + //@@ sequence and result in unexpected warmup failure. + //@@ + uint32 count = 4; +} + +//@@ +//@@ .. cpp:var:: message ModelOperations +//@@ +//@@ The metadata of libraries providing custom operations for this model. +//@@ +message ModelOperations +{ + //@@ .. cpp:var:: string op_library_filename (repeated) + //@@ + //@@ Optional paths of the libraries providing custom operations for + //@@ this model. Valid only for ONNX models. + //@@ + repeated string op_library_filename = 1; +} + +//@@ +//@@ .. cpp:var:: message ModelTransactionPolicy +//@@ +//@@ The specification that describes the nature of transactions +//@@ to be expected from the model. +//@@ +message ModelTransactionPolicy +{ + //@@ .. 
cpp:var:: bool decoupled + //@@ + //@@ Indicates whether responses generated by the model are decoupled with + //@@ the requests issued to it, which means the number of responses + //@@ generated by model may differ from number of requests issued, and + //@@ that the responses may be out of order relative to the order of + //@@ requests. The default is false, which means the model will generate + //@@ exactly one response for each request. + //@@ + bool decoupled = 1; +} + +//@@ +//@@.. cpp:var:: message ModelRepositoryAgents +//@@ +//@@ The repository agents for the model. +//@@ +message ModelRepositoryAgents +{ + //@@ + //@@ .. cpp:var:: message Agent + //@@ + //@@ A repository agent that should be invoked for the specified + //@@ repository actions for this model. + //@@ + message Agent + { + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the agent. + //@@ + string name = 1; + + //@@ .. cpp:var:: map parameters + //@@ + //@@ The parameters for the agent. + //@@ + map parameters = 2; + } + + //@@ + //@@ .. cpp:var:: Agent agents (repeated) + //@@ + //@@ The ordered list of agents for the model. These agents will be + //@@ invoked in order to respond to repository actions occuring for the + //@@ model. + //@@ + repeated Agent agents = 1; +} + +//@@ +//@@.. cpp:var:: message ModelResponseCache +//@@ +//@@ The response cache setting for the model. +//@@ +message ModelResponseCache +{ + //@@ + //@@ .. cpp::var:: bool enable + //@@ + //@@ Whether or not to use response cache for the model. If True, the + //@@ responses from the model are cached and when identical request + //@@ is encountered, instead of going through the model execution, + //@@ the response from the cache is utilized. By default, response + //@@ cache is disabled for the models. + //@@ + bool enable = 1; +} + +//@@ +//@@.. cpp:var:: message ModelConfig +//@@ +//@@ A model configuration. +//@@ +message ModelConfig +{ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the model. + //@@ + string name = 1; + + //@@ .. cpp:var:: string platform + //@@ + //@@ The framework for the model. Possible values are + //@@ "tensorrt_plan", "tensorflow_graphdef", + //@@ "tensorflow_savedmodel", "onnxruntime_onnx", + //@@ "pytorch_libtorch". + //@@ + string platform = 2; + + //@@ .. cpp:var:: string backend + //@@ + //@@ The backend used by the model. + //@@ + string backend = 17; + + //@@ .. cpp:var:: ModelVersionPolicy version_policy + //@@ + //@@ Policy indicating which version(s) of the model will be served. + //@@ + ModelVersionPolicy version_policy = 3; + + //@@ .. cpp:var:: int32 max_batch_size + //@@ + //@@ Maximum batch size allowed for inference. This can only decrease + //@@ what is allowed by the model itself. A max_batch_size value of 0 + //@@ indicates that batching is not allowed for the model and the + //@@ dimension/shape of the input and output tensors must exactly + //@@ match what is specified in the input and output configuration. A + //@@ max_batch_size value > 0 indicates that batching is allowed and + //@@ so the model expects the input tensors to have an additional + //@@ initial dimension for the batching that is not specified in the + //@@ input (for example, if the model supports batched inputs of + //@@ 2-dimensional tensors then the model configuration will specify + //@@ the input shape as [ X, Y ] but the model will expect the actual + //@@ input tensors to have shape [ N, X, Y ]). 
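+  //@@  As a concrete, illustrative config.pbtxt fragment (the tensor name
+  //@@  and shape are placeholders), the input below is declared without
+  //@@  the batch dimension and at runtime the model receives tensors of
+  //@@  shape [ N, 3, 224, 224 ] with 1 <= N <= 8:
+  //@@
+  //@@    max_batch_size: 8
+  //@@    input [
+  //@@      { name: "INPUT0" data_type: TYPE_FP32 dims: [ 3, 224, 224 ] }
+  //@@    ]
+  //@@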
For max_batch_size > 0 + //@@ returned outputs will also have an additional initial dimension + //@@ for the batch. + //@@ + int32 max_batch_size = 4; + + //@@ .. cpp:var:: ModelInput input (repeated) + //@@ + //@@ The inputs request by the model. + //@@ + repeated ModelInput input = 5; + + //@@ .. cpp:var:: ModelOutput output (repeated) + //@@ + //@@ The outputs produced by the model. + //@@ + repeated ModelOutput output = 6; + + //@@ .. cpp:var:: BatchInput batch_input (repeated) + //@@ + //@@ The model input(s) that the server should use to communicate + //@@ batch related values to the model. + //@@ + repeated BatchInput batch_input = 20; + + //@@ .. cpp:var:: BatchOutput batch_output (repeated) + //@@ + //@@ The outputs produced by the model that requires special handling + //@@ by the model backend. + //@@ + repeated BatchOutput batch_output = 21; + + //@@ .. cpp:var:: ModelOptimizationPolicy optimization + //@@ + //@@ Optimization configuration for the model. If not specified + //@@ then default optimization policy is used. + //@@ + ModelOptimizationPolicy optimization = 12; + + //@@ .. cpp:var:: oneof scheduling_choice + //@@ + //@@ The scheduling policy for the model. If not specified the + //@@ default scheduling policy is used for the model. The default + //@@ policy is to execute each inference request independently. + //@@ + oneof scheduling_choice + { + //@@ .. cpp:var:: ModelDynamicBatching dynamic_batching + //@@ + //@@ If specified, enables the dynamic-batching scheduling + //@@ policy. With dynamic-batching the scheduler may group + //@@ together independent requests into a single batch to + //@@ improve inference throughput. + //@@ + ModelDynamicBatching dynamic_batching = 11; + + //@@ .. cpp:var:: ModelSequenceBatching sequence_batching + //@@ + //@@ If specified, enables the sequence-batching scheduling + //@@ policy. With sequence-batching, inference requests + //@@ with the same correlation ID are routed to the same + //@@ model instance. Multiple sequences of inference requests + //@@ may be batched together into a single batch to + //@@ improve inference throughput. + //@@ + ModelSequenceBatching sequence_batching = 13; + + //@@ .. cpp:var:: ModelEnsembling ensemble_scheduling + //@@ + //@@ If specified, enables the model-ensembling scheduling + //@@ policy. With model-ensembling, inference requests + //@@ will be processed according to the specification, such as an + //@@ execution sequence of models. The input specified in this model + //@@ config will be the input for the ensemble, and the output + //@@ specified will be the output of the ensemble. + //@@ + ModelEnsembling ensemble_scheduling = 15; + } + + //@@ .. cpp:var:: ModelInstanceGroup instance_group (repeated) + //@@ + //@@ Instances of this model. If not specified, one instance + //@@ of the model will be instantiated on each available GPU. + //@@ + repeated ModelInstanceGroup instance_group = 7; + + //@@ .. cpp:var:: string default_model_filename + //@@ + //@@ Optional filename of the model file to use if a + //@@ compute-capability specific model is not specified in + //@@ :cpp:var:`cc_model_filenames`. If not specified the default name + //@@ is 'model.graphdef', 'model.savedmodel', 'model.plan' or + //@@ 'model.pt' depending on the model type. + //@@ + string default_model_filename = 8; + + //@@ .. cpp:var:: map cc_model_filenames + //@@ + //@@ Optional map from CUDA compute capability to the filename of + //@@ the model that supports that compute capability. 
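+  //@@  For example (the filenames are hypothetical), a TensorRT model
+  //@@  could map compute capabilities 7.5 and 8.0 to separately built
+  //@@  engine files:
+  //@@
+  //@@    cc_model_filenames { key: "7.5" value: "model_sm75.plan" }
+  //@@    cc_model_filenames { key: "8.0" value: "model_sm80.plan" }
+  //@@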
The filename + //@@ refers to a file within the model version directory. + //@@ + map cc_model_filenames = 9; + + //@@ .. cpp:var:: map metric_tags + //@@ + //@@ Optional metric tags. User-specific key-value pairs for metrics + //@@ reported for this model. These tags are applied to the metrics + //@@ reported on the HTTP metrics port. + //@@ + map metric_tags = 10; + + //@@ .. cpp:var:: map parameters + //@@ + //@@ Optional model parameters. User-specified parameter values. + //@@ + map parameters = 14; + + //@@ .. cpp:var:: ModelWarmup model_warmup (repeated) + //@@ + //@@ Warmup setting of this model. If specified, all instances + //@@ will be run with the request samples in sequence before + //@@ serving the model. + //@@ This field can only be specified if the model is not an ensemble + //@@ model. + //@@ + repeated ModelWarmup model_warmup = 16; + + //@@ .. cpp:var:: ModelOperations model_operations + //@@ + //@@ Optional metadata of the libraries providing custom operations for + //@@ this model. + //@@ + ModelOperations model_operations = 18; + + //@@ .. cpp:var:: ModelTransactionPolicy model_transaction_policy + //@@ + //@@ Optional specification that describes the nature of transactions + //@@ to be expected from the model. + //@@ + ModelTransactionPolicy model_transaction_policy = 19; + + //@@ .. cpp:var:: ModelRepositoryAgents model_repository_agents + //@@ + //@@ Optional specification of the agent(s) that should be invoked + //@@ with repository actions are performed for this model. + //@@ + ModelRepositoryAgents model_repository_agents = 23; + + //@@ .. cpp:var:: ModelResponseCache response_cache + //@@ + //@@ Optional setting for utilizing the response cache for this + //@@ model. + //@@ + ModelResponseCache response_cache = 24; +} diff --git a/3rdparty/common-r22.12/src/async_work_queue.cc b/3rdparty/common-r22.12/src/async_work_queue.cc new file mode 100644 index 0000000000000000000000000000000000000000..ebd56b9a3c5b2876ac3caeebfff0ceb75083e715 --- /dev/null +++ b/3rdparty/common-r22.12/src/async_work_queue.cc @@ -0,0 +1,97 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "triton/common/async_work_queue.h" + +namespace triton { namespace common { + +AsyncWorkQueue::~AsyncWorkQueue() +{ + GetSingleton()->thread_pool_.reset(); +} + +AsyncWorkQueue* +AsyncWorkQueue::GetSingleton() +{ + static AsyncWorkQueue singleton; + return &singleton; +} + +Error +AsyncWorkQueue::Initialize(size_t worker_count) +{ + if (worker_count < 1) { + return Error( + Error::Code::INVALID_ARG, + "Async work queue must be initialized with positive 'worker_count'"); + } + + static std::mutex init_mtx; + std::lock_guard lk(init_mtx); + + if (GetSingleton()->thread_pool_) { + return Error( + Error::Code::ALREADY_EXISTS, + "Async work queue has been initialized with " + + std::to_string(GetSingleton()->thread_pool_->Size()) + + " 'worker_count'"); + } + + GetSingleton()->thread_pool_.reset(new ThreadPool(worker_count)); + return Error::Success; +} + +size_t +AsyncWorkQueue::WorkerCount() +{ + if (!GetSingleton()->thread_pool_) { + return 0; + } + return GetSingleton()->thread_pool_->Size(); +} + +Error +AsyncWorkQueue::AddTask(std::function&& task) +{ + if (!GetSingleton()->thread_pool_) { + return Error( + Error::Code::UNAVAILABLE, + "Async work queue must be initialized before adding task"); + } + GetSingleton()->thread_pool_->Enqueue(std::move(task)); + + return Error::Success; +} + +void +AsyncWorkQueue::Reset() +{ + // Reconstruct the singleton to reset it + GetSingleton()->~AsyncWorkQueue(); + new (GetSingleton()) AsyncWorkQueue(); +} + +}} // namespace triton::common diff --git a/3rdparty/common-r22.12/src/error.cc b/3rdparty/common-r22.12/src/error.cc new file mode 100644 index 0000000000000000000000000000000000000000..b6da386fa162f246a682dd6f9aef34c22ceffd77 --- /dev/null +++ b/3rdparty/common-r22.12/src/error.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "triton/common/error.h" + +namespace triton { namespace common { + +const Error Error::Success(Error::Code::SUCCESS); + +std::string +Error::AsString() const +{ + std::string str(CodeString(code_)); + str += ": " + msg_; + return str; +} + +const char* +Error::CodeString(const Code code) +{ + switch (code) { + case Error::Code::SUCCESS: + return "OK"; + case Error::Code::UNKNOWN: + return "Unknown"; + case Error::Code::INTERNAL: + return "Internal"; + case Error::Code::NOT_FOUND: + return "Not found"; + case Error::Code::INVALID_ARG: + return "Invalid argument"; + case Error::Code::UNAVAILABLE: + return "Unavailable"; + case Error::Code::UNSUPPORTED: + return "Unsupported"; + case Error::Code::ALREADY_EXISTS: + return "Already exists"; + default: + break; + } + + return ""; +} + +}} // namespace triton::common diff --git a/3rdparty/common-r22.12/src/logging.cc b/3rdparty/common-r22.12/src/logging.cc new file mode 100644 index 0000000000000000000000000000000000000000..67b01ba8ba1bd1253e7b1156f7346b662714c660 --- /dev/null +++ b/3rdparty/common-r22.12/src/logging.cc @@ -0,0 +1,147 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "triton/common/logging.h" + +#ifdef _WIN32 +// suppress the min and max definitions in Windef.h. 
+#define NOMINMAX +#include +#else +#include +#include +#include +#include +#endif +#include +#include +#include + +namespace triton { namespace common { + +Logger gLogger_; + +Logger::Logger() + : enables_{true, true, true}, vlevel_(0), format_(Format::kDEFAULT) +{ +} + +void +Logger::Log(const std::string& msg) +{ + const std::lock_guard lock(mutex_); + if (file_stream_.is_open()) { + file_stream_ << msg << std::endl; + } else { + std::cerr << msg << std::endl; + } +} + +void +Logger::Flush() +{ + std::cerr << std::flush; +} + + +const std::vector LogMessage::level_name_{'E', 'W', 'I'}; + +LogMessage::LogMessage(const char* file, int line, uint32_t level) +{ + std::string path(file); + size_t pos = path.rfind('/'); + if (pos != std::string::npos) { + path = path.substr(pos + 1, std::string::npos); + } + + // 'L' below is placeholder for showing log level + switch (gLogger_.LogFormat()) { + case Logger::Format::kDEFAULT: { + // LMMDD hh:mm:ss.ssssss +#ifdef _WIN32 + SYSTEMTIME system_time; + GetSystemTime(&system_time); + stream_ << level_name_[std::min(level, (uint32_t)Level::kINFO)] + << std::setfill('0') << std::setw(2) << system_time.wMonth + << std::setw(2) << system_time.wDay << ' ' << std::setw(2) + << system_time.wHour << ':' << std::setw(2) << system_time.wMinute + << ':' << std::setw(2) << system_time.wSecond << '.' + << std::setw(6) << system_time.wMilliseconds * 1000 << ' ' + << static_cast(GetCurrentProcessId()) << ' ' << path + << ':' << line << "] "; +#else + struct timeval tv; + gettimeofday(&tv, NULL); + struct tm tm_time; + gmtime_r(((time_t*)&(tv.tv_sec)), &tm_time); + stream_ << level_name_[std::min(level, (uint32_t)Level::kINFO)] + << std::setfill('0') << std::setw(2) << (tm_time.tm_mon + 1) + << std::setw(2) << tm_time.tm_mday << ' ' << std::setw(2) + << tm_time.tm_hour << ':' << std::setw(2) << tm_time.tm_min << ':' + << std::setw(2) << tm_time.tm_sec << '.' 
<< std::setw(6) + << tv.tv_usec << ' ' << static_cast(getpid()) << ' ' + << path << ':' << line << "] "; +#endif + break; + } + case Logger::Format::kISO8601: { + // YYYY-MM-DDThh:mm:ssZ L +#ifdef _WIN32 + SYSTEMTIME system_time; + GetSystemTime(&system_time); + stream_ << system_time.wYear << '-' << std::setfill('0') << std::setw(2) + << system_time.wMonth << '-' << std::setw(2) << system_time.wDay + << 'T' << std::setw(2) << system_time.wHour << ':' << std::setw(2) + << system_time.wMinute << ':' << std::setw(2) + << system_time.wSecond << "Z " + << level_name_[std::min(level, (uint32_t)Level::kINFO)] << ' ' + << static_cast(GetCurrentProcessId()) << ' ' << path + << ':' << line << "] "; +#else + struct timeval tv; + gettimeofday(&tv, NULL); + struct tm tm_time; + gmtime_r(((time_t*)&(tv.tv_sec)), &tm_time); + stream_ << (tm_time.tm_year + 1900) << '-' << std::setfill('0') + << std::setw(2) << (tm_time.tm_mon + 1) << '-' << std::setw(2) + << tm_time.tm_mday << 'T' << std::setw(2) << tm_time.tm_hour + << ':' << std::setw(2) << tm_time.tm_min << ':' << std::setw(2) + << tm_time.tm_sec << "Z " + << level_name_[std::min(level, (uint32_t)Level::kINFO)] << ' ' + << static_cast(getpid()) << ' ' << path << ':' << line + << "] "; +#endif + break; + } + } +} + +LogMessage::~LogMessage() +{ + gLogger_.Log(stream_.str()); +} + +}} // namespace triton::common diff --git a/3rdparty/common-r22.12/src/model_config.cc b/3rdparty/common-r22.12/src/model_config.cc new file mode 100644 index 0000000000000000000000000000000000000000..e459ef071f2cb35c7e27136ef74e9008f4b712a7 --- /dev/null +++ b/3rdparty/common-r22.12/src/model_config.cc @@ -0,0 +1,443 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#include "triton/common/model_config.h" + +namespace triton { namespace common { + +bool +IsFixedSizeDataType(const inference::DataType dtype) +{ + return dtype != inference::DataType::TYPE_STRING; +} + +size_t +GetDataTypeByteSize(const inference::DataType dtype) +{ + switch (dtype) { + case inference::DataType::TYPE_BOOL: + return 1; + case inference::DataType::TYPE_UINT8: + return 1; + case inference::DataType::TYPE_UINT16: + return 2; + case inference::DataType::TYPE_UINT32: + return 4; + case inference::DataType::TYPE_UINT64: + return 8; + case inference::DataType::TYPE_INT8: + return 1; + case inference::DataType::TYPE_INT16: + return 2; + case inference::DataType::TYPE_INT32: + return 4; + case inference::DataType::TYPE_INT64: + return 8; + case inference::DataType::TYPE_FP16: + return 2; + case inference::DataType::TYPE_FP32: + return 4; + case inference::DataType::TYPE_FP64: + return 8; + case inference::DataType::TYPE_STRING: + return 0; + case inference::DataType::TYPE_BF16: + return 2; + default: + break; + } + + return 0; +} + +int64_t +GetElementCount(const DimsList& dims) +{ + bool first = true; + int64_t cnt = 0; + for (auto dim : dims) { + if (dim == WILDCARD_DIM) { + return -1; + } + + if (first) { + cnt = dim; + first = false; + } else { + cnt *= dim; + } + } + + return cnt; +} + +int64_t +GetElementCount(const std::vector& dims) +{ + bool first = true; + int64_t cnt = 0; + for (auto dim : dims) { + if (dim == WILDCARD_DIM) { + return -1; + } + + if (first) { + cnt = dim; + first = false; + } else { + cnt *= dim; + } + } + + return cnt; +} + +int64_t +GetElementCount(const inference::ModelInput& mio) +{ + return GetElementCount(mio.dims()); +} + +int64_t +GetElementCount(const inference::ModelOutput& mio) +{ + return GetElementCount(mio.dims()); +} + +int64_t +GetByteSize(const inference::DataType& dtype, const DimsList& dims) +{ + size_t dt_size = GetDataTypeByteSize(dtype); + if (dt_size == 0) { + return -1; + } + + int64_t cnt = GetElementCount(dims); + if (cnt == -1) { + return -1; + } + + return cnt * dt_size; +} + +int64_t +GetByteSize(const inference::DataType& dtype, const std::vector& dims) +{ + size_t dt_size = GetDataTypeByteSize(dtype); + if (dt_size == 0) { + return -1; + } + + int64_t cnt = GetElementCount(dims); + if (cnt == -1) { + return -1; + } + + return cnt * dt_size; +} + +int64_t +GetByteSize( + const int batch_size, const inference::DataType& dtype, + const DimsList& dims) +{ + if (dims.size() == 0) { + return batch_size * GetDataTypeByteSize(dtype); + } + + int64_t bs = GetByteSize(dtype, dims); + if (bs == -1) { + return -1; + } + + return std::max(1, batch_size) * bs; +} + +int64_t +GetByteSize( + const int batch_size, const inference::DataType& dtype, + const std::vector& dims) +{ + if (dims.size() == 0) { + return batch_size * GetDataTypeByteSize(dtype); + } + + int64_t bs = GetByteSize(dtype, dims); + if (bs == -1) { + return -1; + } + + return std::max(1, batch_size) * bs; +} + +int64_t +GetByteSize(const inference::ModelInput& mio) +{ + return GetByteSize(mio.data_type(), mio.dims()); +} + +int64_t +GetByteSize(const inference::ModelOutput& mio) +{ + return GetByteSize(mio.data_type(), mio.dims()); +} + +int +GetCpuNiceLevel(const inference::ModelConfig& config) +{ + int nice = SCHEDULER_DEFAULT_NICE; + if (config.has_optimization()) { + switch (config.optimization().priority()) { + case inference::ModelOptimizationPolicy::PRIORITY_MAX: + nice = 0; + break; + case inference::ModelOptimizationPolicy::PRIORITY_MIN: + nice = 19; + break; + 
default: + nice = SCHEDULER_DEFAULT_NICE; + break; + } + } + + return nice; +} + +bool +CompareDims(const DimsList& dims0, const DimsList& dims1) +{ + if (dims0.size() != dims1.size()) { + return false; + } + + for (int i = 0; i < dims0.size(); ++i) { + if (dims0[i] != dims1[i]) { + return false; + } + } + + return true; +} + +bool +CompareDims( + const std::vector& dims0, const std::vector& dims1) +{ + if (dims0.size() != dims1.size()) { + return false; + } + + for (size_t i = 0; i < dims0.size(); ++i) { + if (dims0[i] != dims1[i]) { + return false; + } + } + + return true; +} + +bool +CompareDimsWithWildcard(const DimsList& dims0, const DimsList& dims1) +{ + if (dims0.size() != dims1.size()) { + return false; + } + + for (int i = 0; i < dims0.size(); ++i) { + if ((dims0[i] != WILDCARD_DIM) && (dims1[i] != WILDCARD_DIM) && + (dims0[i] != dims1[i])) { + return false; + } + } + + return true; +} + +bool +CompareDimsWithWildcard( + const DimsList& dims0, const std::vector& dims1) +{ + if (dims0.size() != (int64_t)dims1.size()) { + return false; + } + + for (int i = 0; i < dims0.size(); ++i) { + if ((dims0[i] != WILDCARD_DIM) && (dims1[i] != WILDCARD_DIM) && + (dims0[i] != dims1[i])) { + return false; + } + } + + return true; +} + +std::string +DimsListToString(const DimsList& dims) +{ + bool first = true; + + std::string str("["); + for (const auto& dim : dims) { + if (!first) { + str += ","; + } + str += std::to_string(dim); + first = false; + } + + str += "]"; + return str; +} + +std::string +DimsListToString(const std::vector& dims, const int start_idx) +{ + int idx = 0; + + std::string str("["); + for (const auto& dim : dims) { + if (idx >= start_idx) { + if (idx > start_idx) { + str += ","; + } + str += std::to_string(dim); + } + + idx++; + } + + str += "]"; + return str; +} + +const char* +DataTypeToProtocolString(const inference::DataType dtype) +{ + switch (dtype) { + case inference::DataType::TYPE_BOOL: + return "BOOL"; + case inference::DataType::TYPE_UINT8: + return "UINT8"; + case inference::DataType::TYPE_UINT16: + return "UINT16"; + case inference::DataType::TYPE_UINT32: + return "UINT32"; + case inference::DataType::TYPE_UINT64: + return "UINT64"; + case inference::DataType::TYPE_INT8: + return "INT8"; + case inference::DataType::TYPE_INT16: + return "INT16"; + case inference::DataType::TYPE_INT32: + return "INT32"; + case inference::DataType::TYPE_INT64: + return "INT64"; + case inference::DataType::TYPE_FP16: + return "FP16"; + case inference::DataType::TYPE_FP32: + return "FP32"; + case inference::DataType::TYPE_FP64: + return "FP64"; + case inference::DataType::TYPE_STRING: + return "BYTES"; + case inference::DataType::TYPE_BF16: + return "BF16"; + default: + break; + } + + return ""; +} + +inference::DataType +ProtocolStringToDataType(const std::string& dtype) +{ + return ProtocolStringToDataType(dtype.c_str(), dtype.size()); +} + +inference::DataType +ProtocolStringToDataType(const char* dtype, size_t len) +{ + if (len < 4 || len > 6) { + return inference::DataType::TYPE_INVALID; + } + + if ((*dtype == 'I') && (len != 6)) { + if ((dtype[1] == 'N') && (dtype[2] == 'T')) { + if ((dtype[3] == '8') && (len == 4)) { + return inference::DataType::TYPE_INT8; + } else if ((dtype[3] == '1') && (dtype[4] == '6')) { + return inference::DataType::TYPE_INT16; + } else if ((dtype[3] == '3') && (dtype[4] == '2')) { + return inference::DataType::TYPE_INT32; + } else if ((dtype[3] == '6') && (dtype[4] == '4')) { + return inference::DataType::TYPE_INT64; + } + } + } else if ((*dtype == 
'U') && (len != 4)) {
+    if ((dtype[1] == 'I') && (dtype[2] == 'N') && (dtype[3] == 'T')) {
+      if ((dtype[4] == '8') && (len == 5)) {
+        return inference::DataType::TYPE_UINT8;
+      } else if ((dtype[4] == '1') && (dtype[5] == '6')) {
+        return inference::DataType::TYPE_UINT16;
+      } else if ((dtype[4] == '3') && (dtype[5] == '2')) {
+        return inference::DataType::TYPE_UINT32;
+      } else if ((dtype[4] == '6') && (dtype[5] == '4')) {
+        return inference::DataType::TYPE_UINT64;
+      }
+    }
+  } else if ((*dtype == 'F') && (dtype[1] == 'P') && (len == 4)) {
+    if ((dtype[2] == '1') && (dtype[3] == '6')) {
+      return inference::DataType::TYPE_FP16;
+    } else if ((dtype[2] == '3') && (dtype[3] == '2')) {
+      return inference::DataType::TYPE_FP32;
+    } else if ((dtype[2] == '6') && (dtype[3] == '4')) {
+      return inference::DataType::TYPE_FP64;
+    }
+  } else if (*dtype == 'B') {
+    switch (dtype[1]) {
+      case 'Y':
+        if (!strcmp(dtype + 2, "TES")) {
+          return inference::DataType::TYPE_STRING;
+        }
+        break;
+      case 'O':
+        if (!strcmp(dtype + 2, "OL")) {
+          return inference::DataType::TYPE_BOOL;
+        }
+        break;
+      case 'F':
+        if (!strcmp(dtype + 2, "16")) {
+          return inference::DataType::TYPE_BF16;
+        }
+        break;
+    }
+  }
+
+  return inference::DataType::TYPE_INVALID;
+}
+
+}}  // namespace triton::common
diff --git a/3rdparty/common-r22.12/src/table_printer.cc b/3rdparty/common-r22.12/src/table_printer.cc
new file mode 100644
index 0000000000000000000000000000000000000000..779f49921f8cd43c8e1543c465d0819ae1413989
--- /dev/null
+++ b/3rdparty/common-r22.12/src/table_printer.cc
@@ -0,0 +1,261 @@
+// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//  * Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//  * Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "triton/common/table_printer.h"
+
+#ifdef _WIN32
+// suppress the min and max definitions in Windef.h.
+#define NOMINMAX
+#include <Windows.h>
+#else
+#include <sys/ioctl.h>
+#include <unistd.h>
+#endif
+#include <algorithm>
+#include <iomanip>
+#include <iostream>
+#include <memory>
+#include <numeric>
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace triton { namespace common {
+
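Before the table printer below, a quick illustration of how the model_config.cc helpers above behave. This is only a sketch, assuming the declarations from `triton/common/model_config.h` (including `WILDCARD_DIM` and the protobuf-generated `inference::DataType`); it is not part of the library itself.

```
#include <cstdint>
#include <iostream>
#include <vector>

#include "triton/common/model_config.h"

int
main()
{
  using namespace triton::common;

  // Fully specified shape: element count and byte size are well defined.
  std::vector<int64_t> shape{4, 3, 224, 224};
  std::cout << GetElementCount(shape) << "\n";  // 602112
  std::cout << GetByteSize(inference::DataType::TYPE_FP32, shape) << "\n";  // 2408448

  // Any wildcard (-1) dimension makes both quantities unknown, so -1 is returned.
  std::vector<int64_t> dynamic{WILDCARD_DIM, 3, 224, 224};
  std::cout << GetElementCount(dynamic) << "\n";  // -1

  // Protocol string round trip; note that "BYTES" maps to TYPE_STRING.
  std::cout << DataTypeToProtocolString(ProtocolStringToDataType("FP16")) << "\n";  // FP16
  return 0;
}
```

The -1 convention shown here is what the `GetByteSize` overloads rely on when a model declares dynamic dimensions.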
+//
+// ASCII table printer.
+//
+void
+TablePrinter::InsertRow(const std::vector<std::string>& row)
+{
+  std::vector<std::vector<std::string>> table_row;
+
+  // Number of lines in each field in the record
+  size_t max_height = 0;
+
+  // Update max length of data items in each row
+  for (size_t i = 0; i < row.size(); ++i) {
+    table_row.push_back(std::vector<std::string>{});
+    std::stringstream ss(row[i]);
+    std::string line;
+
+    size_t max_width = 0;
+    while (std::getline(ss, line, '\n')) {
+      table_row[i].push_back(line);
+      if (line.size() > max_width)
+        max_width = line.size();
+    }
+
+    if (max_width > max_widths_[i])
+      max_widths_[i] = max_width;
+
+    size_t number_of_lines = table_row[i].size();
+    if (max_height < number_of_lines)
+      max_height = number_of_lines;
+  }
+
+  max_heights_.push_back(max_height);
+  data_.emplace_back(table_row);
+}
+
+void
+TablePrinter::FairShare()
+{
+  // initialize original index locations
+  size_t array_size = max_widths_.size();
+  std::vector<size_t> idx(array_size);
+  iota(idx.begin(), idx.end(), 0);
+
+  stable_sort(idx.begin(), idx.end(), [this](size_t i1, size_t i2) {
+    return this->max_widths_[i1] < this->max_widths_[i2];
+  });
+
+  size_t loop_index = 1;
+  for (auto itr = idx.begin(); itr != idx.end(); ++itr) {
+    // If a column is not using all the space allocated to it
+    if (max_widths_[*itr] < shares_[*itr]) {
+      float excess = shares_[*itr] - max_widths_[*itr];
+      shares_[*itr] -= excess;
+
+      if (itr == idx.end() - 1)
+        break;
+      auto update_itr = idx.begin() + (itr - idx.begin() + 1);
+
+      // excess amount of unused space that must be distributed evenly to the
+      // next columns
+      float excess_per_column = excess / (array_size - loop_index);
+
+      for (; update_itr != idx.end(); ++update_itr) {
+        shares_[*update_itr] += excess_per_column;
+        excess -= excess_per_column;
+      }
+    }
+    ++loop_index;
+  }
+
+  // Remove any decimal shares
+  for (auto itr = idx.begin(); itr != idx.end(); ++itr) {
+    shares_[*itr] = (size_t)shares_[*itr];
+  }
+
+  // For each record
+  for (size_t i = 0; i < data_.size(); i++) {
+    auto current_row = data_[i];
+
+    // For each field in the record
+    for (size_t j = 0; j < current_row.size(); j++) {
+      // For each line in the record
+      for (size_t line_index = 0; line_index < current_row[j].size();
+           line_index++) {
+        std::string line = current_row[j][line_index];
+        size_t num_rows = (line.size() + shares_[j] - 1) / shares_[j];
+
+        // If the number of rows required for this record is larger than 1, we
+        // will break that line and put it in multiple lines
+        if (num_rows > 1) {
+          // Remove the multi-line field; it will be replaced by lines that
+          // fit within the column size
+          data_[i][j].erase(data_[i][j].begin() + line_index);
+          for (size_t k = 0; k < num_rows; k++) {
+            size_t start_index =
+                std::min((size_t)(k * shares_[j]), line.size());
+            size_t end_index =
+                std::min((size_t)((k + 1) * shares_[j]), line.size());
+            data_[i][j].insert(
+                data_[i][j].begin() + line_index + k,
+                line.substr(start_index, end_index - start_index));
+          }
+
+          // We need to advance the index for the split lines.
+ line_index += num_rows - 1; + } + + if (max_heights_[i] < (num_rows - 1 + current_row[j].size())) + max_heights_[i] += num_rows - 1; + } + } + } +} + +void +TablePrinter::AddRow(std::stringstream& table, size_t row_index) +{ + auto row = data_[row_index]; + size_t max_height = max_heights_[row_index]; + + for (size_t j = 0; j < max_height; j++) { + table << "|" << std::left; + + for (size_t i = 0; i < row.size(); i++) { + if (j < row[i].size()) + table << " " << std::setw(shares_[i]) << row[i][j] << " |"; + else + table << " " << std::setw(shares_[i]) << " " + << " |"; + } + + // Do not add new line if this is the last row of this record + if (j != max_height - 1) + table << "\n"; + } + table << "\n"; +} + +void +TablePrinter::AddRowDivider(std::stringstream& table) +{ + table << "+"; + for (const auto& share : shares_) { + for (size_t i = 0; i < share + 2; i++) table << "-"; + table << "+"; + } + table << "\n"; +} + +std::string +TablePrinter::PrintTable() +{ + std::stringstream table; + table << "\n"; + + FairShare(); + + AddRowDivider(table); + // Add table headers + AddRow(table, 0); + AddRowDivider(table); + + for (size_t j = 1; j < data_.size(); j++) { + AddRow(table, j); + } + + AddRowDivider(table); + + return table.str(); +} + +// TablePrinter will take the ownership of `headers`. +TablePrinter::TablePrinter(const std::vector& headers) +{ + // terminal size + size_t column_size = 500; +#ifdef _WIN32 + CONSOLE_SCREEN_BUFFER_INFO csbi; + int ret = GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); + if (ret && (csbi.dwSize.X != 0)) { + column_size = csbi.dwSize.X; + } +#else + struct winsize terminal_size; + int status = ioctl(STDOUT_FILENO, TIOCGWINSZ, &terminal_size); + if ((status == 0) && (terminal_size.ws_col != 0)) { + column_size = terminal_size.ws_col; + } +#endif + + for (size_t i = 0; i < headers.size(); ++i) { + max_widths_.emplace_back(0); + } + + // Calculate fair share of every column + size_t number_of_columns = headers.size(); + + // Terminal width is the actual terminal width minus two times spaces + // required before and after each column and number of columns plus 1 for + // the pipes between the columns + size_t terminal_width = + column_size - (2 * number_of_columns) - (number_of_columns + 1); + int equal_share = terminal_width / headers.size(); + + for (size_t i = 0; i < headers.size(); ++i) { + shares_.emplace_back(equal_share); + terminal_width -= equal_share; + } + + InsertRow(headers); +} + +}} // namespace triton::common diff --git a/3rdparty/common-r22.12/src/thread_pool.cc b/3rdparty/common-r22.12/src/thread_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..8f53db71d0430294208b157f9a5c2627fdeac296 --- /dev/null +++ b/3rdparty/common-r22.12/src/thread_pool.cc @@ -0,0 +1,97 @@ +// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. 
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "triton/common/thread_pool.h"
+#include <stdexcept>
+
+namespace triton { namespace common {
+
+ThreadPool::ThreadPool(size_t thread_count)
+{
+  if (!thread_count) {
+    throw std::invalid_argument("Thread count must be greater than zero.");
+  }
+
+  // Define infinite loop for each thread to wait for a task to complete
+  const auto worker_loop = [this]() {
+    while (true) {
+      Task task;
+      {
+        std::unique_lock<std::mutex> lk(queue_mtx_);
+        // Wake if there's a task to do, or the pool has been stopped.
+        cv_.wait(lk, [&]() { return !task_queue_.empty() || stop_; });
+        // Exit condition
+        if (stop_ && task_queue_.empty()) {
+          break;
+        }
+        task = std::move(task_queue_.front());
+        task_queue_.pop();
+      }
+
+      // Execute task - ensure function has a valid target
+      if (task) {
+        task();
+      }
+    }
+  };
+
+  workers_.reserve(thread_count);
+  for (size_t i = 0; i < thread_count; ++i) {
+    workers_.emplace_back(worker_loop);
+  }
+}
+
+ThreadPool::~ThreadPool()
+{
+  {
+    std::lock_guard<std::mutex> lk(queue_mtx_);
+    // Signal to each worker that it should exit loop when tasks are finished
+    stop_ = true;
+  }
+  // Wake all threads to clean up
+  cv_.notify_all();
+  for (auto& t : workers_) {
+    t.join();
+  }
+}
+
+void
+ThreadPool::Enqueue(Task&& task)
+{
+  {
+    std::lock_guard<std::mutex> lk(queue_mtx_);
+    // Don't accept more work if pool is shutting down
+    if (stop_) {
+      return;
+    }
+    task_queue_.push(std::move(task));
+  }
+  // Only wake one thread per task
+  // Todo: DLIS-3859 if ThreadPool gets used more.
+  cv_.notify_one();
+}
+
+}}  // namespace triton::common
diff --git a/3rdparty/common-r22.12/tools/format.py b/3rdparty/common-r22.12/tools/format.py
new file mode 100644
index 0000000000000000000000000000000000000000..84649d3c2f9501a5f6caf1b85baa12b90e4c3e03
--- /dev/null
+++ b/3rdparty/common-r22.12/tools/format.py
@@ -0,0 +1,116 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import os +import subprocess +import yapf + +FLAGS = None +FORMAT_EXTS = ('proto', 'cc', 'cu', 'h') +SKIP_PATHS = ('tools',) + + +def visit(path): + if FLAGS.verbose: + print("visiting " + path) + + valid_ext = False + python_file = False + for ext in FORMAT_EXTS: + if path.endswith('.' + ext): + valid_ext = True + break + if path.endswith('.py'): + valid_ext = True + python_file = True + if not valid_ext: + if FLAGS.verbose: + print("skipping due to extension: " + path) + return True + + for skip in SKIP_PATHS: + if path.startswith(skip): + if FLAGS.verbose: + print("skipping due to path prefix: " + path) + return True + if python_file: + yapf.yapflib.yapf_api.FormatFile(path, + in_place=True, + style_config='google') + return True + else: + args = ['clang-format-6.0', '--style=file', '-i'] + if FLAGS.verbose: + args.append('-verbose') + args.append(path) + + ret = subprocess.call(args) + if ret != 0: + print("format failed for " + path) + return False + + return True + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-v', + '--verbose', + action="store_true", + required=False, + default=False, + help='Enable verbose output') + parser.add_argument('paths', + type=str, + nargs='*', + default=None, + help='Directories or files to format') + FLAGS = parser.parse_args() + + # Check the version of yapf. Needs a consistent version + # of yapf to prevent unneccessary changes in the code. + if (yapf.__version__ != '0.30.0'): + print("Needs yapf 0.30.0, but got yapf {}".format(yapf.__version__)) + + if (FLAGS.paths is None) or (len(FLAGS.paths) == 0): + parser.print_help() + exit(1) + + ret = True + for path in FLAGS.paths: + if not os.path.isdir(path): + if not visit(path): + ret = False + else: + for root, dirs, files in os.walk(path): + for name in files: + if not visit(os.path.join(root, name)): + ret = False + + exit(0 if ret else 1) diff --git a/3rdparty/common-r22.12/tools/pre-commit b/3rdparty/common-r22.12/tools/pre-commit new file mode 100644 index 0000000000000000000000000000000000000000..5e8ba370716b1981e96ee2c45165326bbba3a1e3 --- /dev/null +++ b/3rdparty/common-r22.12/tools/pre-commit @@ -0,0 +1,56 @@ +#!/bin/bash +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +############################################################################### +# +# Git pre-commit hook for Triton related projects +# +# To install this hook for a project, copy "pre-commit" and "format.py" into +# ".git/hooks/" directory of the project +# +############################################################################### + +############################################################################### +# +# Run formatter script +# +############################################################################### + +# Repo root +GIT_REPO_ROOT=$(git rev-parse --show-toplevel) + +PYTHON_CMD=python3 +FORMATTER_PY=${GIT_REPO_ROOT}/.git/hooks/format.py + +CHANGED_FILES="$(git --no-pager diff --name-status --no-color --cached | awk '{ if (match($1, /R[0-9]+/)) { print $3 } else if ($1 != "D") { print $2 } }')" + +echo "Running Python auto-format..." 
+for CHANGED_FILE in $CHANGED_FILES; +do + ${PYTHON_CMD} ${FORMATTER_PY} ${GIT_REPO_ROOT}/${CHANGED_FILE} + git add ${GIT_REPO_ROOT}/${CHANGED_FILE} +done diff --git a/3rdparty/core-r22.12/.clang-format b/3rdparty/core-r22.12/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..98c649734c29e0b1d134dae65be9bc08a14b4ba5 --- /dev/null +++ b/3rdparty/core-r22.12/.clang-format @@ -0,0 +1,37 @@ +--- +BasedOnStyle: Google + +IndentWidth: 2 +ContinuationIndentWidth: 4 +UseTab: Never +MaxEmptyLinesToKeep: 2 + +SortIncludes: true +CompactNamespaces: true +ReflowComments: true + +DerivePointerAlignment: false +PointerAlignment: Left + +AllowShortIfStatementsOnASingleLine: false +AllowShortBlocksOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline + +AlwaysBreakAfterReturnType: TopLevelDefinitions +AlignAfterOpenBracket: AlwaysBreak +BreakBeforeBraces: Custom +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: true + AfterNamespace: false + AfterStruct: false + AfterUnion: false + BeforeCatch: true + +BinPackArguments: true +BinPackParameters: true +ConstructorInitializerAllOnOneLineOrOnePerLine: false + +IndentCaseLabels: true \ No newline at end of file diff --git a/3rdparty/core-r22.12/.gitignore b/3rdparty/core-r22.12/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0e9f099a2eef4742716637e3cce3a45f7053b021 --- /dev/null +++ b/3rdparty/core-r22.12/.gitignore @@ -0,0 +1,3 @@ +/build +/.vscode +*.so diff --git a/3rdparty/core-r22.12/LICENSE b/3rdparty/core-r22.12/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..237621c146a3b2f3b43a26eba937393b8e1a6f0c --- /dev/null +++ b/3rdparty/core-r22.12/LICENSE @@ -0,0 +1,25 @@ +Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of NVIDIA CORPORATION nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
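As the comments in tools/pre-commit above note, the hook expects format.py to sit next to it in `.git/hooks/`. The following is a minimal installation sketch, assuming a project that vendors these files at the paths used in this repository; the tool versions (yapf 0.30.0, clang-format-6.0) are the ones format.py itself checks for or invokes.

```
$ cp 3rdparty/common-r22.12/tools/format.py  .git/hooks/format.py
$ cp 3rdparty/common-r22.12/tools/pre-commit .git/hooks/pre-commit
$ chmod +x .git/hooks/pre-commit
$ pip install yapf==0.30.0   # format.py warns if any other yapf version is installed
```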
diff --git a/3rdparty/core-r22.12/README.md b/3rdparty/core-r22.12/README.md new file mode 100644 index 0000000000000000000000000000000000000000..457844d1307c05713a5f6a957ea7a77c47312ec2 --- /dev/null +++ b/3rdparty/core-r22.12/README.md @@ -0,0 +1,104 @@ + + +[![License](https://img.shields.io/badge/License-BSD3-lightgrey.svg)](https://opensource.org/licenses/BSD-3-Clause) + +# Triton Inference Server Core + +This repository holds the source code and headers for the library that +implements the core functionality of Triton. The *core* library can be +built as described below and used directly via its [C +API](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/inference_protocols.md#in-process-triton-server-api). To +be useful the core library must be paired with one or more backends. +You can learn more about backends in the [backend +repo](https://github.com/triton-inference-server/backend). + +Typically you do not build or use the core library on its own, but as +part of the *tritonserver* executable. The *tritonserver* executable +is built in the [server +repo](https://github.com/triton-inference-server/server) as described +in the [server build +documentation](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/build.md). + +Ask questions or report problems in the main Triton [issues +page](https://github.com/triton-inference-server/server/issues). + +## Build the Triton Core Library + +Before building the Triton core library, your build system must +install the required dependencies described in the [build +documentation](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/build.md). For +example, if you are building the core library with GPU support +(-DTRITON_ENABLE_GPU=ON), then you must install the CUDA, cuDNN, and +TensorRT dependencies required for the version of Triton you are +building. + +To build, first clone the release branch matching the Triton release +you are interest in (*rxx.yy*), or the *main* branch to build the +top-of-tree. The Triton core library is built with CMake. + +``` +$ mkdir build +$ cd build +$ cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DTRITON_CORE_HEADERS_ONLY=OFF .. +$ make install +``` + +When the build completes, the install directory will contain the +Triton core shared library (install/lib/libtritonserver.so on Linux, +install/bin/tritonserver.dll on Windows), and the core library headers +files in install/include/triton/core. + +### Build a Release Branch + +The following required Triton repositories will be pulled and used in +the build. By default the "main" branch/tag will be used for each repo +but the listed CMake argument can be used to override. + +* triton-inference-server/third_party: -DTRITON_THIRD_PARTY_REPO_TAG=[tag] +* triton-inference-server/common: -DTRITON_COMMON_REPO_TAG=[tag] + +You will need to override if you are building from a release +branch. For example, if you are building the r22.03 version of Triton, +you would clone the r22.03 branch of the core repo and use the +following cmake command. + +``` +$ cmake -DTRITON_THIRD_PARTY_REPO_TAG=r22.03 -DTRITON_COMMON_REPO_TAG=r22.03 -DTRITON_CORE_HEADERS_ONLY=OFF .. +``` + +### Build Options + +The [CMakeLists.txt](CMakeLists.txt) file contains the options +available when build the core library. For example, to build the core +library with the default settings plus S3 cloud storage and ensembling +support use the following command. 
+ +``` +$ cmake -DTRITON_CORE_HEADERS_ONLY=OFF -DTRITON_ENABLE_S3=ON -DTRITON_ENABLE_ENSEMBLE=ON .. +``` diff --git a/3rdparty/core-r22.12/cmake/TritonCoreConfig.cmake.in b/3rdparty/core-r22.12/cmake/TritonCoreConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..05ba9db1845980877d8814171c2d8fad6fc61a08 --- /dev/null +++ b/3rdparty/core-r22.12/cmake/TritonCoreConfig.cmake.in @@ -0,0 +1,37 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include(CMakeFindDependencyMacro) + +get_filename_component( + TRITONCORE_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH +) + +list(APPEND CMAKE_MODULE_PATH ${TRITONCORE_CMAKE_DIR}) + +if(NOT TARGET TritonCore::triton-core-serverapi) + include("${TRITONCORE_CMAKE_DIR}/TritonCoreTargets.cmake") +endif() diff --git a/3rdparty/core-r22.12/include/triton/core/tritonbackend.h b/3rdparty/core-r22.12/include/triton/core/tritonbackend.h new file mode 100644 index 0000000000000000000000000000000000000000..9d800183abbf5511d61a10e036d1d6142cbc0625 --- /dev/null +++ b/3rdparty/core-r22.12/include/triton/core/tritonbackend.h @@ -0,0 +1,1410 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. 
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+#include "triton/core/tritonserver.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef _COMPILING_TRITONBACKEND
+#if defined(_MSC_VER)
+#define TRITONBACKEND_DECLSPEC __declspec(dllexport)
+#define TRITONBACKEND_ISPEC __declspec(dllimport)
+#elif defined(__GNUC__)
+#define TRITONBACKEND_DECLSPEC __attribute__((__visibility__("default")))
+#define TRITONBACKEND_ISPEC
+#else
+#define TRITONBACKEND_DECLSPEC
+#define TRITONBACKEND_ISPEC
+#endif
+#else
+#if defined(_MSC_VER)
+#define TRITONBACKEND_DECLSPEC __declspec(dllimport)
+#define TRITONBACKEND_ISPEC __declspec(dllexport)
+#else
+#define TRITONBACKEND_DECLSPEC
+#define TRITONBACKEND_ISPEC
+#endif
+#endif
+
+struct TRITONBACKEND_MemoryManager;
+struct TRITONBACKEND_Input;
+struct TRITONBACKEND_Output;
+struct TRITONBACKEND_State;
+struct TRITONBACKEND_Request;
+struct TRITONBACKEND_ResponseFactory;
+struct TRITONBACKEND_Response;
+struct TRITONBACKEND_Backend;
+struct TRITONBACKEND_Model;
+struct TRITONBACKEND_ModelInstance;
+struct TRITONBACKEND_BackendAttribute;
+
+///
+/// TRITONBACKEND API Version
+///
+/// The TRITONBACKEND API is versioned with major and minor version
+/// numbers. Any change to the API that does not impact backwards
+/// compatibility (for example, adding a non-required function)
+/// increases the minor version number. Any change that breaks
+/// backwards compatibility (for example, deleting or changing the
+/// behavior of a function) increases the major version number. A
+/// backend should check that the API version used to compile the
+/// backend is compatible with the API version of the Triton server
+/// that it is running in. This is typically done by code similar to
+/// the following which makes sure that the major versions are equal
+/// and that the minor version of Triton is >= the minor version used
+/// to build the backend.
+///
+///   uint32_t api_version_major, api_version_minor;
+///   TRITONBACKEND_ApiVersion(&api_version_major, &api_version_minor);
+///   if ((api_version_major != TRITONBACKEND_API_VERSION_MAJOR) ||
+///       (api_version_minor < TRITONBACKEND_API_VERSION_MINOR)) {
+///     return TRITONSERVER_ErrorNew(
+///         TRITONSERVER_ERROR_UNSUPPORTED,
+///         "triton backend API version does not support this backend");
+///   }
+///
+#define TRITONBACKEND_API_VERSION_MAJOR 1
+#define TRITONBACKEND_API_VERSION_MINOR 10
+
+/// Get the TRITONBACKEND API version supported by Triton. This value
+/// can be compared against the TRITONBACKEND_API_VERSION_MAJOR and
+/// TRITONBACKEND_API_VERSION_MINOR used to build the backend to
+/// ensure that Triton is compatible with the backend.
+///
+/// \param major Returns the TRITONBACKEND API major version supported
+/// by Triton.
+/// \param minor Returns the TRITONBACKEND API minor version supported +/// by Triton. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ApiVersion( + uint32_t* major, uint32_t* minor); + +/// TRITONBACKEND_ArtifactType +/// +/// The ways that the files that make up a backend or model are +/// communicated to the backend. +/// +/// TRITONBACKEND_ARTIFACT_FILESYSTEM: The model or backend +/// artifacts are made available to Triton via a locally +/// accessible filesystem. The backend can access these files +/// using an appropriate system API. +/// +typedef enum TRITONBACKEND_artifacttype_enum { + TRITONBACKEND_ARTIFACT_FILESYSTEM +} TRITONBACKEND_ArtifactType; + + +/// +/// TRITONBACKEND_MemoryManager +/// +/// Object representing an memory manager that is capable of +/// allocating and otherwise managing different memory types. For +/// improved performance Triton maintains pools for GPU and CPU-pinned +/// memory and the memory manager allows backends to access those +/// pools. +/// + +/// Allocate a contiguous block of memory of a specific type using a +/// memory manager. Two error codes have specific interpretations for +/// this function: +/// +/// TRITONSERVER_ERROR_UNSUPPORTED: Indicates that Triton is +/// incapable of allocating the requested memory type and memory +/// type ID. Requests for the memory type and ID will always fail +/// no matter 'byte_size' of the request. +/// +/// TRITONSERVER_ERROR_UNAVAILABLE: Indicates that Triton can +/// allocate the memory type and ID but that currently it cannot +/// allocate a contiguous block of memory of the requested +/// 'byte_size'. +/// +/// \param manager The memory manager. +/// \param buffer Returns the allocated memory. +/// \param memory_type The type of memory to allocate. +/// \param memory_type_id The ID associated with the memory type to +/// allocate. For GPU memory this indicates the device ID of the GPU +/// to allocate from. +/// \param byte_size The size of memory to allocate, in bytes. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_MemoryManagerAllocate( + TRITONBACKEND_MemoryManager* manager, void** buffer, + const TRITONSERVER_MemoryType memory_type, const int64_t memory_type_id, + const uint64_t byte_size); + +/// Free a buffer that was previously allocated with +/// TRITONBACKEND_MemoryManagerAllocate. The call must provide the +/// same values for 'memory_type' and 'memory_type_id' as were used +/// when the buffer was allocate or else the behavior is undefined. +/// +/// \param manager The memory manager. +/// \param buffer The allocated memory buffer to free. +/// \param memory_type The type of memory of the buffer. +/// \param memory_type_id The ID associated with the memory type of +/// the buffer. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_MemoryManagerFree( + TRITONBACKEND_MemoryManager* manager, void* buffer, + const TRITONSERVER_MemoryType memory_type, const int64_t memory_type_id); + +/// +/// TRITONBACKEND_Input +/// +/// Object representing an input tensor. +/// + +/// Get the name and properties of an input tensor. The returned +/// strings and other properties are owned by the input, not the +/// caller, and so should not be modified or freed. +/// +/// \param input The input tensor. +/// \param name If non-nullptr, returns the tensor name. 
+/// \param datatype If non-nullptr, returns the tensor datatype. +/// \param shape If non-nullptr, returns the tensor shape. +/// \param dim_count If non-nullptr, returns the number of dimensions +/// in the tensor shape. +/// \param byte_size If non-nullptr, returns the size of the available +/// data for the tensor, in bytes. This size reflects the actual data +/// available, and does not necessarily match what is +/// expected/required for the tensor given its shape and datatype. It +/// is the responsibility of the backend to handle mismatches in these +/// sizes appropriately. +/// \param buffer_count If non-nullptr, returns the number of buffers +/// holding the contents of the tensor. These buffers are accessed +/// using TRITONBACKEND_InputBuffer. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_InputProperties( + TRITONBACKEND_Input* input, const char** name, + TRITONSERVER_DataType* datatype, const int64_t** shape, + uint32_t* dims_count, uint64_t* byte_size, uint32_t* buffer_count); + +/// Get the name and properties of an input tensor associated with a given +/// host policy. If there are no input buffers for the specified host policy, +/// the properties of the fallback input buffers are returned. The returned +/// strings and other properties are owned by the input, not the caller, and so +/// should not be modified or freed. +/// +/// \param input The input tensor. +/// \param host_policy_name The host policy name. Fallback input properties +/// will be return if nullptr is provided. +/// \param name If non-nullptr, returns the tensor name. +/// \param datatype If non-nullptr, returns the tensor datatype. +/// \param shape If non-nullptr, returns the tensor shape. +/// \param dim_count If non-nullptr, returns the number of dimensions +/// in the tensor shape. +/// \param byte_size If non-nullptr, returns the size of the available +/// data for the tensor, in bytes. This size reflects the actual data +/// available, and does not necessarily match what is +/// expected/required for the tensor given its shape and datatype. It +/// is the responsibility of the backend to handle mismatches in these +/// sizes appropriately. +/// \param buffer_count If non-nullptr, returns the number of buffers +/// holding the contents of the tensor. These buffers are accessed +/// using TRITONBACKEND_InputBufferForHostPolicy. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_InputPropertiesForHostPolicy( + TRITONBACKEND_Input* input, const char* host_policy_name, const char** name, + TRITONSERVER_DataType* datatype, const int64_t** shape, + uint32_t* dims_count, uint64_t* byte_size, uint32_t* buffer_count); + +/// Get a buffer holding (part of) the tensor data for an input. For a +/// given input the number of buffers composing the input are found +/// from 'buffer_count' returned by TRITONBACKEND_InputProperties. The +/// returned buffer is owned by the input and so should not be +/// modified or freed by the caller. The lifetime of the buffer +/// matches that of the input and so the buffer should not be accessed +/// after the input tensor object is released. +/// +/// \param input The input tensor. +/// \param index The index of the buffer. Must be 0 <= index < +/// buffer_count, where buffer_count is the value returned by +/// TRITONBACKEND_InputProperties. 
+/// \param buffer Returns a pointer to a contiguous block of data for +/// the named input. +/// \param buffer_byte_size Returns the size, in bytes, of 'buffer'. +/// \param memory_type Acts as both input and output. On input gives +/// the buffer memory type preferred by the function caller. Returns +/// the actual memory type of 'buffer'. +/// \param memory_type_id Acts as both input and output. On input +/// gives the buffer memory type id preferred by the function caller. +/// Returns the actual memory type id of 'buffer'. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_InputBuffer( + TRITONBACKEND_Input* input, const uint32_t index, const void** buffer, + uint64_t* buffer_byte_size, TRITONSERVER_MemoryType* memory_type, + int64_t* memory_type_id); + +/// Get a buffer holding (part of) the tensor data for an input for a specific +/// host policy. If there are no input buffers specified for this host policy, +/// the fallback input buffer is returned. +/// For a given input the number of buffers composing the input are found +/// from 'buffer_count' returned by TRITONBACKEND_InputPropertiesForHostPolicy. +/// The returned buffer is owned by the input and so should not be modified or +/// freed by the caller. The lifetime of the buffer matches that of the input +/// and so the buffer should not be accessed after the input tensor object is +/// released. +/// +/// \param input The input tensor. +/// \param host_policy_name The host policy name. Fallback input buffer +/// will be return if nullptr is provided. +/// \param index The index of the buffer. Must be 0 <= index < +/// buffer_count, where buffer_count is the value returned by +/// TRITONBACKEND_InputPropertiesForHostPolicy. +/// \param buffer Returns a pointer to a contiguous block of data for +/// the named input. +/// \param buffer_byte_size Returns the size, in bytes, of 'buffer'. +/// \param memory_type Acts as both input and output. On input gives +/// the buffer memory type preferred by the function caller. Returns +/// the actual memory type of 'buffer'. +/// \param memory_type_id Acts as both input and output. On input +/// gives the buffer memory type id preferred by the function caller. +/// Returns the actual memory type id of 'buffer'. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_InputBufferForHostPolicy( + TRITONBACKEND_Input* input, const char* host_policy_name, + const uint32_t index, const void** buffer, uint64_t* buffer_byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id); + +/// Get the buffer attributes associated with the given input buffer. For a +/// given input the number of buffers composing the input are found from +/// 'buffer_count' returned by TRITONBACKEND_InputProperties. The returned +/// 'buffer_attributes' is owned by the input and so should not be modified or +/// freed by the caller. The lifetime of the 'buffer_attributes' matches that of +/// the input and so the 'buffer_attributes' should not be accessed after the +/// input tensor object is released. +/// +/// \param input The input tensor. +/// \param index The index of the buffer. Must be 0 <= index < buffer_count, +/// where buffer_count is the value returned by TRITONBACKEND_InputProperties. +/// \param buffer Returns a pointer to a contiguous block of data for +/// the named input. +/// \param buffer_attributes Returns the attributes for the given buffer. 
+/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_InputBufferAttributes( + TRITONBACKEND_Input* input, const uint32_t index, const void** buffer, + TRITONSERVER_BufferAttributes** buffer_attributes); + +/// +/// TRITONBACKEND_Output +/// +/// Object representing a response output tensor. +/// + +/// Get a buffer to use to hold the tensor data for the output. The +/// returned buffer is owned by the output and so should not be freed +/// by the caller. The caller can and should fill the buffer with the +/// output data for the tensor. The lifetime of the buffer matches +/// that of the output and so the buffer should not be accessed after +/// the output tensor object is released. +/// +/// \param buffer Returns a pointer to a buffer where the contents of +/// the output tensor should be placed. +/// \param buffer_byte_size The size, in bytes, of the buffer required +/// by the caller. +/// \param memory_type Acts as both input and output. On input gives +/// the buffer memory type preferred by the caller. Returns the +/// actual memory type of 'buffer'. +/// \param memory_type_id Acts as both input and output. On input +/// gives the buffer memory type id preferred by the caller. Returns +/// the actual memory type id of 'buffer'. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_OutputBuffer( + TRITONBACKEND_Output* output, void** buffer, + const uint64_t buffer_byte_size, TRITONSERVER_MemoryType* memory_type, + int64_t* memory_type_id); + +/// Get the buffer attributes associated with the given output buffer. The +/// returned 'buffer_attributes' is owned by the output and so should not be +/// modified or freed by the caller. The lifetime of the 'buffer_attributes' +/// matches that of the output and so the 'buffer_attributes' should not be +/// accessed after the output tensor object is released. This function must be +/// called after the TRITONBACKEND_OutputBuffer otherwise it might contain +/// incorrect data. +/// +/// \param output The output tensor. +/// \param buffer_attributes Returns the attributes for the output buffer. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_OutputBufferAttributes( + TRITONBACKEND_Output* output, + TRITONSERVER_BufferAttributes** buffer_attributes); + +/// +/// TRITONBACKEND_Request +/// +/// Object representing an inference request. +/// + +/// Get the ID of the request. Can be nullptr if request doesn't have +/// an ID. The returned string is owned by the request, not the +/// caller, and so should not be modified or freed. +/// +/// \param request The inference request. +/// \param id Returns the ID. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestId( + TRITONBACKEND_Request* request, const char** id); + +/// Get the correlation ID of the request if it is an unsigned integer. +/// Zero indicates that the request does not have a correlation ID. +/// Returns failure if correlation ID for given request is not an unsigned +/// integer. +/// +/// \param request The inference request. +/// \param id Returns the correlation ID. +/// \return a TRITONSERVER_Error indicating success or failure. 
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestCorrelationId( + TRITONBACKEND_Request* request, uint64_t* id); + +/// Get the correlation ID of the request if it is a string. +/// Empty string indicates that the request does not have a correlation ID. +/// Returns error if correlation ID for given request is not a string. +/// +/// \param request The inference request. +/// \param id Returns the correlation ID. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestCorrelationIdString( + TRITONBACKEND_Request* request, const char** id); + +/// Get the flag(s) associated with a request. On return 'flags' holds +/// a bitwise-or of all flag values, see TRITONSERVER_RequestFlag for +/// available flags. +/// +/// \param request The inference request. +/// \param flags Returns the flags. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestFlags( + TRITONBACKEND_Request* request, uint32_t* flags); + +/// Get the number of input tensors specified in the request. +/// +/// \param request The inference request. +/// \param count Returns the number of input tensors. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestInputCount( + TRITONBACKEND_Request* request, uint32_t* count); + +/// Get the name of an input tensor. The caller does not own +/// the returned string and must not modify or delete it. The lifetime +/// of the returned string extends only as long as 'request'. +/// +/// \param request The inference request. +/// \param index The index of the input tensor. Must be 0 <= index < +/// count, where count is the value returned by +/// TRITONBACKEND_RequestInputCount. +/// \param input_name Returns the name of the input tensor +/// corresponding to the index. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestInputName( + TRITONBACKEND_Request* request, const uint32_t index, + const char** input_name); + +/// Get a named request input. The lifetime of the returned input +/// object matches that of the request and so the input object should +/// not be accessed after the request object is released. +/// +/// \param request The inference request. +/// \param name The name of the input. +/// \param input Returns the input corresponding to the name. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestInput( + TRITONBACKEND_Request* request, const char* name, + TRITONBACKEND_Input** input); + +/// Get a request input by index. The order of inputs in a given +/// request is not necessarily consistent with other requests, even if +/// the requests are in the same batch. As a result, you can not +/// assume that an index obtained from one request will point to the +/// same input in a different request. +/// +/// The lifetime of the returned input object matches that of the +/// request and so the input object should not be accessed after the +/// request object is released. +/// +/// \param request The inference request. +/// \param index The index of the input tensor. Must be 0 <= index < +/// count, where count is the value returned by +/// TRITONBACKEND_RequestInputCount. +/// \param input Returns the input corresponding to the index. 
+/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestInputByIndex( + TRITONBACKEND_Request* request, const uint32_t index, + TRITONBACKEND_Input** input); + +/// Get the number of output tensors requested to be returned in the +/// request. +/// +/// \param request The inference request. +/// \param count Returns the number of output tensors. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestOutputCount( + TRITONBACKEND_Request* request, uint32_t* count); + +/// Get the name of a requested output tensor. The caller does not own +/// the returned string and must not modify or delete it. The lifetime +/// of the returned string extends only as long as 'request'. +/// +/// \param request The inference request. +/// \param index The index of the requested output tensor. Must be 0 +/// <= index < count, where count is the value returned by +/// TRITONBACKEND_RequestOutputCount. +/// \param output_name Returns the name of the requested output tensor +/// corresponding to the index. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestOutputName( + TRITONBACKEND_Request* request, const uint32_t index, + const char** output_name); + +/// Returns the preferred memory type and memory type ID of the output buffer +/// for the request. As much as possible, Triton will attempt to return +/// the same memory_type and memory_type_id values that will be returned by +/// the subsequent call to TRITONBACKEND_OutputBuffer, however, the backend must +/// be capable of handling cases where the values differ. +/// +/// \param request The request. +/// \param name The name of the output tensor. This is optional +/// and it should be set to nullptr to indicate that the tensor name has +/// not determined. +/// \param byte_size The expected size of the buffer. This is optional +/// and it should be set to nullptr to indicate that the byte size has +/// not determined. +/// \param memory_type Acts as both input and output. On input gives +/// the memory type preferred by the caller. Returns memory type preferred +/// by Triton, taken account of the caller preferred type. +/// \param memory_type_id Acts as both input and output. On input gives +/// the memory type ID preferred by the caller. Returns memory type ID preferred +/// by Triton, taken account of the caller preferred type ID. +/// \return a TRITONSERVER_Error object if a failure occurs. +/// A TRITONSERVER_ERROR_UNAVAILABLE error indicates that the properties are not +/// available, other error codes indicate an error. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestOutputBufferProperties( + TRITONBACKEND_Request* request, const char* name, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id); + +/// Release the request. The request should be released when it is no +/// longer needed by the backend. If this call returns with an error +/// (i.e. non-nullptr) then the request was not released and ownership +/// remains with the backend. If this call returns with success, the +/// 'request' object is no longer owned by the backend and must not be +/// used. Any tensor names, data types, shapes, input tensors, +/// etc. returned by TRITONBACKEND_Request* functions for this request +/// are no longer valid. 
If a persistent copy of that data is required +/// it must be created before calling this function. +/// +/// \param request The inference request. +/// \param release_flags Flags indicating what type of request release +/// should be performed. \see TRITONSERVER_RequestReleaseFlag. \see +/// TRITONSERVER_InferenceRequestReleaseFn_t. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestRelease( + TRITONBACKEND_Request* request, uint32_t release_flags); + +/// +/// TRITONBACKEND_ResponseFactory +/// +/// Object representing an inference response factory. Using a +/// response factory is not required; instead a response can be +/// generated directly from a TRITONBACKEND_Request object using +/// TRITONBACKEND_ResponseNew(). A response factory allows a request +/// to be released before all responses have been sent. Releasing a +/// request as early as possible releases all input tensor data and +/// therefore may be desirable in some cases. + +/// Create the response factory associated with a request. +/// +/// \param factory Returns the new response factory. +/// \param request The inference request. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ResponseFactoryNew( + TRITONBACKEND_ResponseFactory** factory, TRITONBACKEND_Request* request); + +/// Destroy a response factory. +/// +/// \param factory The response factory. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ResponseFactoryDelete( + TRITONBACKEND_ResponseFactory* factory); + +/// Send response flags without a corresponding response. +/// +/// \param factory The response factory. +/// \param send_flags Flags to send. \see +/// TRITONSERVER_ResponseCompleteFlag. \see +/// TRITONSERVER_InferenceResponseCompleteFn_t. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ResponseFactorySendFlags( + TRITONBACKEND_ResponseFactory* factory, const uint32_t send_flags); + +/// +/// TRITONBACKEND_Response +/// +/// Object representing an inference response. For a given request, +/// the backend must carefully manage the lifecycle of responses +/// generated for that request to ensure that the output tensor +/// buffers are allocated correctly. When a response is created with +/// TRITONBACKEND_ResponseNew or TRITONBACKEND_ResponseNewFromFactory, +/// all the outputs and corresponding buffers must be created for that +/// response using TRITONBACKEND_ResponseOutput and +/// TRITONBACKEND_OutputBuffer *before* another response is created +/// for the request. For a given response, outputs can be created in +/// any order but they must be created sequentially/sychronously (for +/// example, the backend cannot use multiple threads to simultaneously +/// add multiple outputs to a response). +/// +/// The above requirement applies only to responses being generated +/// for a given request. The backend may generate responses in +/// parallel on multiple threads as long as those responses are for +/// different requests. +/// +/// This order of response creation must be strictly followed. But, +/// once response(s) are created they do not need to be sent +/// immediately, nor do they need to be sent in the order they were +/// created. The backend may even delete a created response instead of +/// sending it by using TRITONBACKEND_ResponseDelete. 
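+///
+/// For illustration only, a minimal sketch of the flow described above for
+/// a single response (error checking is omitted, and "OUTPUT0", 'shape',
+/// 'dims_count' and 'byte_size' are placeholders that the backend is assumed
+/// to know):
+///
+///   TRITONBACKEND_Response* response;
+///   TRITONBACKEND_ResponseNew(&response, request);
+///
+///   TRITONBACKEND_Output* output;
+///   TRITONBACKEND_ResponseOutput(
+///       response, &output, "OUTPUT0", TRITONSERVER_TYPE_FP32, shape,
+///       dims_count);
+///
+///   void* buffer;
+///   TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;
+///   int64_t memory_type_id = 0;
+///   TRITONBACKEND_OutputBuffer(
+///       output, &buffer, byte_size, &memory_type, &memory_type_id);
+///   // ... fill 'buffer' with the output tensor contents ...
+///
+///   TRITONBACKEND_ResponseSend(
+///       response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr /* success */);
+///   TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL);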
+ +/// Create a response for a request. +/// +/// \param response Returns the new response. +/// \param request The request. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ResponseNew( + TRITONBACKEND_Response** response, TRITONBACKEND_Request* request); + +/// Create a response using a factory. +/// +/// \param response Returns the new response. +/// \param factory The response factory. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ResponseNewFromFactory( + TRITONBACKEND_Response** response, TRITONBACKEND_ResponseFactory* factory); + +/// Destroy a response. It is not necessary to delete a response if +/// TRITONBACKEND_ResponseSend is called as that function transfers +/// ownership of the response object to Triton. +/// +/// \param response The response. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ResponseDelete( + TRITONBACKEND_Response* response); + +/// Set a string parameter in the response. +/// +/// \param response The response. +/// \param name The name of the parameter. +/// \param value The value of the parameter. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ResponseSetStringParameter( + TRITONBACKEND_Response* response, const char* name, const char* value); + +/// Set an integer parameter in the response. +/// +/// \param response The response. +/// \param name The name of the parameter. +/// \param value The value of the parameter. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ResponseSetIntParameter( + TRITONBACKEND_Response* response, const char* name, const int64_t value); + +/// Set an boolean parameter in the response. +/// +/// \param response The response. +/// \param name The name of the parameter. +/// \param value The value of the parameter. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ResponseSetBoolParameter( + TRITONBACKEND_Response* response, const char* name, const bool value); + +/// Create an output tensor in the response. The lifetime of the +/// returned output tensor object matches that of the response and so +/// the output tensor object should not be accessed after the response +/// object is deleted. +/// +/// \param response The response. +/// \param output Returns the new response output. +/// \param name The name of the output tensor. +/// \param datatype The datatype of the output tensor. +/// \param shape The shape of the output tensor. +/// \param dims_count The number of dimensions in the output tensor +/// shape. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ResponseOutput( + TRITONBACKEND_Response* response, TRITONBACKEND_Output** output, + const char* name, const TRITONSERVER_DataType datatype, + const int64_t* shape, const uint32_t dims_count); + +/// Send a response. Calling this function transfers ownership of the +/// response object to Triton. The caller must not access or delete +/// the response object after calling this function. +/// +/// \param response The response. +/// \param send_flags Flags associated with the response. \see +/// TRITONSERVER_ResponseCompleteFlag. 
\see +/// TRITONSERVER_InferenceResponseCompleteFn_t. +/// \param error The TRITONSERVER_Error to send if the response is an +/// error, or nullptr if the response is successful. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ResponseSend( + TRITONBACKEND_Response* response, const uint32_t send_flags, + TRITONSERVER_Error* error); + +/// +/// TRITONBACKEND_State +/// +/// Object representing a state. +/// + +/// Create a state in the request. The returned state object is only valid +/// before the TRITONBACKEND_StateUpdate is called. The state should not be +/// freed by the caller. If TRITONBACKEND_StateUpdate is not called, the +/// lifetime of the state matches the lifetime of the request. If the state name +/// does not exist in the "state" section of the model configuration, the state +/// will not be created and an error will be returned. If this function is +/// called when sequence batching is not enabled or there is no 'states' section +/// in the sequence batching section of the model configuration, this call will +/// return an error. +/// +/// \param state Returns the new state. +/// \param request The request. +/// \param name The name of the state. +/// \param datatype The datatype of the state. +/// \param shape The shape of the state. +/// \param dims_count The number of dimensions in the state shape. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_StateNew( + TRITONBACKEND_State** state, TRITONBACKEND_Request* request, + const char* name, const TRITONSERVER_DataType datatype, + const int64_t* shape, const uint32_t dims_count); + +/// Update the state for the sequence. Calling this function will replace the +/// state stored for this seqeunce in Triton with 'state' provided in the +/// function argument. If this function is called when sequence batching is not +/// enabled or there is no 'states' section in the sequence batching section of +/// the model configuration, this call will return an error. The backend is not +/// required to call this function. If the backend doesn't call +/// TRITONBACKEND_StateUpdate function, this particular state for the sequence +/// will not be updated and the next inference request in the sequence will use +/// the same state as the current inference request. +/// +/// \param state The state. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_StateUpdate( + TRITONBACKEND_State* state); + +/// Get a buffer to use to hold the tensor data for the state. The returned +/// buffer is owned by the state and so should not be freed by the caller. The +/// caller can and should fill the buffer with the state data. The buffer must +/// not be accessed by the backend after TRITONBACKEND_StateUpdate is called. +/// The caller should fill the buffer before calling TRITONBACKEND_StateUpdate. +/// +/// \param state The state. +/// \param buffer Returns a pointer to a buffer where the contents of the state +/// should be placed. +/// \param buffer_byte_size The size, in bytes, of the buffer required +/// by the caller. +/// \param memory_type Acts as both input and output. On input gives +/// the buffer memory type preferred by the caller. Returns the +/// actual memory type of 'buffer'. +/// \param memory_type_id Acts as both input and output. On input +/// gives the buffer memory type id preferred by the caller. 
Returns +/// the actual memory type id of 'buffer'.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_StateBuffer(
+    TRITONBACKEND_State* state, void** buffer, const uint64_t buffer_byte_size,
+    TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id);
+
+/// Get the buffer attributes associated with the given state buffer.
+/// The returned 'buffer_attributes' is owned by the state and so should not be
+/// modified or freed by the caller. The lifetime of the 'buffer_attributes'
+/// matches that of the state.
+///
+/// \param state The state.
+/// \param buffer_attributes Returns the buffer attributes for the given state.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_StateBufferAttributes(
+    TRITONBACKEND_State* state,
+    TRITONSERVER_BufferAttributes** buffer_attributes);
+
+///
+/// TRITONBACKEND_Backend
+///
+/// Object representing a backend.
+///
+
+/// TRITONBACKEND_ExecutionPolicy
+///
+/// Types of execution policy that can be implemented by a backend.
+///
+///   TRITONBACKEND_EXECUTION_BLOCKING: An instance of the model
+///     blocks in TRITONBACKEND_ModelInstanceExecute until it is ready
+///     to handle another inference. Upon returning from
+///     TRITONBACKEND_ModelInstanceExecute, Triton may immediately
+///     call TRITONBACKEND_ModelInstanceExecute for the same instance
+///     to execute a new batch of requests. Thus, most backends using
+///     this policy will not return from
+///     TRITONBACKEND_ModelInstanceExecute until all responses have
+///     been sent and all requests have been released. This is the
+///     default execution policy.
+///
+///   TRITONBACKEND_EXECUTION_DEVICE_BLOCKING: An instance, A, of the
+///     model blocks in TRITONBACKEND_ModelInstanceExecute if the
+///     device associated with the instance is unable to handle
+///     another inference. Even if another instance, B, associated
+///     with the device, is available and ready to perform an
+///     inference, Triton will not invoke
+///     TRITONBACKEND_ModelInstanceExecute for B until A returns from
+///     TRITONBACKEND_ModelInstanceExecute. Triton will not be blocked
+///     from calling TRITONBACKEND_ModelInstanceExecute for instance
+///     C, which is associated with a different device than A and B,
+///     even if A or B has not returned from
+///     TRITONBACKEND_ModelInstanceExecute. This execution policy is
+///     typically used by a backend that can cooperatively execute
+///     multiple model instances on the same device.
+///
+typedef enum TRITONBACKEND_execpolicy_enum {
+  TRITONBACKEND_EXECUTION_BLOCKING,
+  TRITONBACKEND_EXECUTION_DEVICE_BLOCKING
+} TRITONBACKEND_ExecutionPolicy;
+
+/// Get the name of the backend. The caller does not own the returned
+/// string and must not modify or delete it. The lifetime of the
+/// returned string extends only as long as 'backend'.
+///
+/// \param backend The backend.
+/// \param name Returns the name of the backend.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_BackendName(
+    TRITONBACKEND_Backend* backend, const char** name);
+
+/// Get the backend configuration. The 'backend_config' message is
+/// owned by Triton and should not be modified or freed by the caller.
+///
+/// The backend configuration, as JSON, is:
+///
+///   {
+///     "cmdline" : {
+///       "<setting>" : "<value>",
+///       ...
+///     }
+///   }
+///
+/// \param backend The backend.
+/// \param backend_config Returns the backend configuration as a message.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_BackendConfig(
+    TRITONBACKEND_Backend* backend, TRITONSERVER_Message** backend_config);
+
+/// Get the execution policy for this backend. By default the
+/// execution policy is TRITONBACKEND_EXECUTION_BLOCKING.
+///
+/// \param backend The backend.
+/// \param policy Returns the execution policy.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_BackendExecutionPolicy(
+    TRITONBACKEND_Backend* backend, TRITONBACKEND_ExecutionPolicy* policy);
+
+/// Set the execution policy for this backend. By default the
+/// execution policy is TRITONBACKEND_EXECUTION_BLOCKING. Triton reads
+/// the backend's execution policy after calling
+/// TRITONBACKEND_Initialize, so to be recognized, changes to the
+/// execution policy must be made in TRITONBACKEND_Initialize.
+/// Also, note that if the sequence batcher is used for the model, Triton will
+/// use the TRITONBACKEND_EXECUTION_BLOCKING policy irrespective of the
+/// policy specified by this setter function.
+///
+/// \param backend The backend.
+/// \param policy The execution policy.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_BackendSetExecutionPolicy(
+    TRITONBACKEND_Backend* backend, TRITONBACKEND_ExecutionPolicy policy);
+
+/// Get the location of the files that make up the backend
+/// implementation. This location contains the backend shared library
+/// and any other files located with the shared library. The
+/// 'location' communicated depends on how the backend is being
+/// communicated to Triton as indicated by 'artifact_type'.
+///
+///   TRITONBACKEND_ARTIFACT_FILESYSTEM: The backend artifacts are
+///     made available to Triton via the local filesystem. 'location'
+///     returns the full path to the directory containing this
+///     backend's artifacts. The returned string is owned by Triton,
+///     not the caller, and so should not be modified or freed.
+///
+/// \param backend The backend.
+/// \param artifact_type Returns the artifact type for the backend.
+/// \param path Returns the location.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_BackendArtifacts(
+    TRITONBACKEND_Backend* backend, TRITONBACKEND_ArtifactType* artifact_type,
+    const char** location);
+
+/// Get the memory manager associated with a backend.
+///
+/// \param backend The backend.
+/// \param manager Returns the memory manager.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_BackendMemoryManager(
+    TRITONBACKEND_Backend* backend, TRITONBACKEND_MemoryManager** manager);
+
+/// Get the user-specified state associated with the backend. The
+/// state is completely owned and managed by the backend.
+///
+/// \param backend The backend.
+/// \param state Returns the user state, or nullptr if no user state.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_BackendState(
+    TRITONBACKEND_Backend* backend, void** state);
+
+/// Set the user-specified state associated with the backend. The
+/// state is completely owned and managed by the backend.
+///
+/// \param backend The backend.
+/// \param state The user state, or nullptr if no user state. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_BackendSetState( + TRITONBACKEND_Backend* backend, void* state); + +/// +/// TRITONBACKEND_Model +/// +/// Object representing a model implemented using the backend. +/// + +/// Get the name of the model. The returned string is owned by the +/// model object, not the caller, and so should not be modified or +/// freed. +/// +/// \param model The model. +/// \param name Returns the model name. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelName( + TRITONBACKEND_Model* model, const char** name); + +/// Get the version of the model. +/// +/// \param model The model. +/// \param version Returns the model version. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelVersion( + TRITONBACKEND_Model* model, uint64_t* version); + +/// Get the location of the files that make up the model. The +/// 'location' communicated depends on how the model is being +/// communicated to Triton as indicated by 'artifact_type'. +/// +/// TRITONBACKEND_ARTIFACT_FILESYSTEM: The model artifacts are made +/// available to Triton via the local filesytem. 'location' +/// returns the full path to the directory in the model repository +/// that contains this model's artifacts. The returned string is +/// owned by Triton, not the caller, and so should not be modified +/// or freed. +/// +/// \param model The model. +/// \param artifact_type Returns the artifact type for the model. +/// \param path Returns the location. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelRepository( + TRITONBACKEND_Model* model, TRITONBACKEND_ArtifactType* artifact_type, + const char** location); + +/// Get the model configuration. The caller takes ownership of the +/// message object and must call TRITONSERVER_MessageDelete to release +/// the object. The configuration is available via this call even +/// before the model is loaded and so can be used in +/// TRITONBACKEND_ModelInitialize. TRITONSERVER_ServerModelConfig +/// returns equivalent information but is not useable until after the +/// model loads. +/// +/// \param model The model. +/// \param config_version The model configuration will be returned in +/// a format matching this version. If the configuration cannot be +/// represented in the requested version's format then an error will +/// be returned. Currently only version 1 is supported. +/// \param model_config Returns the model configuration as a message. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelConfig( + TRITONBACKEND_Model* model, const uint32_t config_version, + TRITONSERVER_Message** model_config); + +/// Whether the backend should attempt to auto-complete the model configuration. +/// If true, the model should fill the inputs, outputs, and max batch size in +/// the model configuration if incomplete. If the model configuration is +/// changed, the new configuration must be reported to Triton using +/// TRITONBACKEND_ModelSetConfig. +/// +/// \param model The model. +/// \param auto_complete_config Returns whether the backend should auto-complete +/// the model configuration. 
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ModelAutoCompleteConfig(
+    TRITONBACKEND_Model* model, bool* auto_complete_config);
+
+/// Set the model configuration in Triton server. This API should only be called
+/// when the backend implements the auto-completion of model configuration
+/// and TRITONBACKEND_ModelAutoCompleteConfig returns true in
+/// auto_complete_config. Only the inputs, outputs, max batch size, and
+/// scheduling choice can be changed. One caveat: the scheduling choice can only
+/// be changed if none was previously set. Any other changes to the model
+/// configuration will be ignored by Triton. This function can only be called
+/// from TRITONBACKEND_ModelInitialize; calling it in any other context will
+/// result in an error being returned. Additionally, Triton server can add some
+/// of the missing fields in the provided config with this call. The backend
+/// must get the complete configuration again by using TRITONBACKEND_ModelConfig.
+/// TRITONBACKEND_ModelSetConfig does not take ownership of the message object
+/// and so the caller should call TRITONSERVER_MessageDelete to release the
+/// object once the function returns.
+///
+/// \param model The model.
+/// \param config_version The format version of the model configuration.
+/// If the configuration is not represented in the version's format
+/// then an error will be returned. Currently only version 1 is supported.
+/// \param model_config The updated model configuration as a message.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelSetConfig(
+    TRITONBACKEND_Model* model, const uint32_t config_version,
+    TRITONSERVER_Message* model_config);
+
+/// Get the TRITONSERVER_Server object that this model is being served
+/// by.
+///
+/// \param model The model.
+/// \param server Returns the server.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelServer(
+    TRITONBACKEND_Model* model, TRITONSERVER_Server** server);
+
+/// Get the backend used by the model.
+///
+/// \param model The model.
+/// \param backend Returns the backend object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelBackend(
+    TRITONBACKEND_Model* model, TRITONBACKEND_Backend** backend);
+
+/// Get the user-specified state associated with the model. The
+/// state is completely owned and managed by the backend.
+///
+/// \param model The model.
+/// \param state Returns the user state, or nullptr if no user state.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelState(
+    TRITONBACKEND_Model* model, void** state);
+
+/// Set the user-specified state associated with the model. The
+/// state is completely owned and managed by the backend.
+///
+/// \param model The model.
+/// \param state The user state, or nullptr if no user state.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelSetState(
+    TRITONBACKEND_Model* model, void* state);
+
+///
+/// TRITONBACKEND_ModelInstance
+///
+/// Object representing a model instance implemented using the
+/// backend.
+///
+
+/// Get the name of the model instance. The returned string is owned by the
+/// model object, not the caller, and so should not be modified or
+/// freed.
+///
+/// \param instance The model instance.
+/// \param name Returns the instance name.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceName(
+    TRITONBACKEND_ModelInstance* instance, const char** name);
+
+/// Get the kind of the model instance.
+///
+/// \param instance The model instance.
+/// \param kind Returns the instance kind.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceKind(
+    TRITONBACKEND_ModelInstance* instance,
+    TRITONSERVER_InstanceGroupKind* kind);
+
+/// Get the device ID of the model instance.
+///
+/// \param instance The model instance.
+/// \param device_id Returns the instance device ID.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceDeviceId(
+    TRITONBACKEND_ModelInstance* instance, int32_t* device_id);
+
+/// Get the host policy setting. The 'host_policy' message is
+/// owned by Triton and should not be modified or freed by the caller.
+///
+/// The host policy setting, as JSON, is:
+///
+///   {
+///     "<policy-name>" : {
+///       "<setting>" : "<value>",
+///       ...
+///     }
+///   }
+///
+/// \param instance The model instance.
+/// \param host_policy Returns the host policy setting as a message.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceHostPolicy(
+    TRITONBACKEND_ModelInstance* instance, TRITONSERVER_Message** host_policy);
+
+/// Whether the model instance is passive.
+///
+/// \param instance The model instance.
+/// \param is_passive Returns true if the instance is passive, false otherwise.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceIsPassive(
+    TRITONBACKEND_ModelInstance* instance, bool* is_passive);
+
+/// Get the number of optimization profiles to be loaded for the instance.
+///
+/// \param instance The model instance.
+/// \param count Returns the number of optimization profiles.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceProfileCount(
+    TRITONBACKEND_ModelInstance* instance, uint32_t* count);
+
+/// Get the name of the optimization profile. The caller does not own
+/// the returned string and must not modify or delete it. The lifetime
+/// of the returned string extends only as long as 'instance'.
+///
+/// \param instance The model instance.
+/// \param index The index of the optimization profile. Must be 0
+/// <= index < count, where count is the value returned by
+/// TRITONBACKEND_ModelInstanceProfileCount.
+/// \param profile_name Returns the name of the optimization profile
+/// corresponding to the index.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceProfileName(
+    TRITONBACKEND_ModelInstance* instance, const uint32_t index,
+    const char** profile_name);
+
+/// Get the number of secondary devices configured for the instance.
+///
+/// \param instance The model instance.
+/// \param count Returns the number of secondary devices.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceSecondaryDeviceCount(
+    TRITONBACKEND_ModelInstance* instance, uint32_t* count);
+
+/// Get the properties of the indexed secondary device. The returned
+/// strings and other properties are owned by the instance, not the
+/// caller, and so should not be modified or freed.
+///
+/// \param instance The model instance.
+/// \param index The index of the secondary device. Must be 0
+/// <= index < count, where count is the value returned by
+/// TRITONBACKEND_ModelInstanceSecondaryDeviceCount.
+/// \param kind Returns the kind of secondary device corresponding
+/// to the index.
+/// \param id Returns the id of secondary device corresponding to the index.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceSecondaryDeviceProperties(
+    TRITONBACKEND_ModelInstance* instance, uint32_t index, const char** kind,
+    int64_t* id);
+
+/// Get the model associated with a model instance.
+///
+/// \param instance The model instance.
+/// \param model Returns the model object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceModel(
+    TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Model** model);
+
+/// Get the user-specified state associated with the model
+/// instance. The state is completely owned and managed by the
+/// backend.
+///
+/// \param instance The model instance.
+/// \param state Returns the user state, or nullptr if no user state.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceState(
+    TRITONBACKEND_ModelInstance* instance, void** state);
+
+/// Set the user-specified state associated with the model
+/// instance. The state is completely owned and managed by the
+/// backend.
+///
+/// \param instance The model instance.
+/// \param state The user state, or nullptr if no user state.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceSetState(
+    TRITONBACKEND_ModelInstance* instance, void* state);
+
+/// Record statistics for an inference request.
+///
+/// Set 'success' to true to indicate that the inference request
+/// completed successfully. In this case all timestamps should be
+/// non-zero values reported in nanoseconds and should be collected
+/// using std::chrono::steady_clock::now().time_since_epoch() or the equivalent.
+/// Set 'success' to false to indicate that the inference request failed
+/// to complete successfully. In this case all timestamp values are
+/// ignored.
+///
+/// For consistency of measurement across different backends, the
+/// timestamps should be collected at the following points during
+/// TRITONBACKEND_ModelInstanceExecute.
+///
+///   TRITONBACKEND_ModelInstanceExecute()
+///     CAPTURE TIMESTAMP (exec_start_ns)
+///     < process input tensors to prepare them for inference
+///       execution, including copying the tensors to/from GPU if
+///       necessary>
+///     CAPTURE TIMESTAMP (compute_start_ns)
+///     < perform inference computations to produce outputs >
+///     CAPTURE TIMESTAMP (compute_end_ns)
+///     < allocate output buffers and extract output tensors, including
+///       copying the tensors to/from GPU if necessary>
+///     CAPTURE TIMESTAMP (exec_end_ns)
+///     return
+///
+/// Note that these statistics are associated with a valid
+/// TRITONBACKEND_Request object and so must be reported before the
+/// request is released. For backends that release the request before
+/// all response(s) are sent, these statistics cannot capture
+/// information about the time required to produce the response.
+///
+/// \param instance The model instance.
+/// \param request The inference request that statistics are being
+/// reported for.
+/// \param success True if the inference request completed
+/// successfully, false if it failed to complete.
+/// \param exec_start_ns Timestamp for the start of execution.
+/// \param compute_start_ns Timestamp for the start of execution
+/// computations.
+/// \param compute_end_ns Timestamp for the end of execution
+/// computations.
+/// \param exec_end_ns Timestamp for the end of execution.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceReportStatistics(
+    TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request* request,
+    const bool success, const uint64_t exec_start_ns,
+    const uint64_t compute_start_ns, const uint64_t compute_end_ns,
+    const uint64_t exec_end_ns);
+
+/// Record statistics for the execution of an entire batch of
+/// inference requests.
+///
+/// All timestamps should be non-zero values reported in nanoseconds
+/// and should be collected using
+/// std::chrono::steady_clock::now().time_since_epoch() or the equivalent.
+/// See TRITONBACKEND_ModelInstanceReportStatistics for more information about
+/// the timestamps.
+///
+/// 'batch_size' is the sum of the batch sizes for the individual
+/// requests that were delivered together in the call to
+/// TRITONBACKEND_ModelInstanceExecute. For example, if three requests
+/// are passed to TRITONBACKEND_ModelInstanceExecute and those
+/// requests have batch sizes 1, 2, and 3, then 'batch_size' should be
+/// set to 6.
+///
+/// \param instance The model instance.
+/// \param batch_size Combined batch size of all the individual
+/// requests executed in the batch.
+/// \param exec_start_ns Timestamp for the start of execution.
+/// \param compute_start_ns Timestamp for the start of execution
+/// computations.
+/// \param compute_end_ns Timestamp for the end of execution
+/// computations.
+/// \param exec_end_ns Timestamp for the end of execution.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceReportBatchStatistics(
+    TRITONBACKEND_ModelInstance* instance, const uint64_t batch_size,
+    const uint64_t exec_start_ns, const uint64_t compute_start_ns,
+    const uint64_t compute_end_ns, const uint64_t exec_end_ns);
+
+///
+/// The following functions can be implemented by a backend. Functions
+/// indicated as required must be implemented or the backend will fail
+/// to load.
+///
+
+/// Initialize a backend.
This function is optional, a backend is not +/// required to implement it. This function is called once when a +/// backend is loaded to allow the backend to initialize any state +/// associated with the backend. A backend has a single state that is +/// shared across all models that use the backend. +/// +/// \param backend The backend. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_Initialize( + TRITONBACKEND_Backend* backend); + +/// Finalize for a backend. This function is optional, a backend is +/// not required to implement it. This function is called once, just +/// before the backend is unloaded. All state associated with the +/// backend should be freed and any threads created for the backend +/// should be exited/joined before returning from this function. +/// +/// \param backend The backend. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_Finalize( + TRITONBACKEND_Backend* backend); + +/// Initialize for a model. This function is optional, a backend is +/// not required to implement it. This function is called once when a +/// model that uses the backend is loaded to allow the backend to +/// initialize any state associated with the model. The backend should +/// also examine the model configuration to determine if the +/// configuration is suitable for the backend. Any errors reported by +/// this function will prevent the model from loading. +/// +/// \param model The model. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInitialize( + TRITONBACKEND_Model* model); + +/// Finalize for a model. This function is optional, a backend is not +/// required to implement it. This function is called once for a +/// model, just before the model is unloaded from Triton. All state +/// associated with the model should be freed and any threads created +/// for the model should be exited/joined before returning from this +/// function. +/// +/// \param model The model. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelFinalize( + TRITONBACKEND_Model* model); + +/// Initialize for a model instance. This function is optional, a +/// backend is not required to implement it. This function is called +/// once when a model instance is created to allow the backend to +/// initialize any state associated with the instance. +/// +/// \param instance The model instance. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceInitialize( + TRITONBACKEND_ModelInstance* instance); + +/// Finalize for a model instance. This function is optional, a +/// backend is not required to implement it. This function is called +/// once for an instance, just before the corresponding model is +/// unloaded from Triton. All state associated with the instance +/// should be freed and any threads created for the instance should be +/// exited/joined before returning from this function. +/// +/// \param instance The model instance. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize( + TRITONBACKEND_ModelInstance* instance); + +/// Execute a batch of one or more requests on a model instance. This +/// function is required. 
Triton will not perform multiple +/// simultaneous calls to this function for a given model 'instance';
+/// however, there may be simultaneous calls for different model
+/// instances (for the same or different models).
+///
+/// If an error is returned, ownership of the request objects
+/// remains with Triton and the backend must not retain references to
+/// the request objects or access them in any way.
+///
+/// If success is returned, ownership of the request objects is
+/// transferred to the backend and it is then responsible for creating
+/// responses and releasing the request objects. Note that even though
+/// ownership of the request objects is transferred to the backend, the
+/// ownership of the buffer holding request pointers is returned back
+/// to Triton upon return from TRITONBACKEND_ModelInstanceExecute. If
+/// any request objects need to be maintained beyond
+/// TRITONBACKEND_ModelInstanceExecute, then the pointers must be copied
+/// out of the array within TRITONBACKEND_ModelInstanceExecute.
+///
+/// \param instance The model instance.
+/// \param requests The requests.
+/// \param request_count The number of requests in the batch.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute(
+    TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests,
+    const uint32_t request_count);
+
+/// Query the backend for different model attributes. This function is optional,
+/// a backend is not required to implement it. The backend is also not required
+/// to set all of the backend attributes listed. This function is called when
+/// Triton requires further backend/model information to perform operations.
+/// This function may be called multiple times within the lifetime of the
+/// backend (between TRITONBACKEND_Initialize and TRITONBACKEND_Finalize).
+/// The backend may return an error to indicate a failure to set the backend
+/// attributes, in which case the attributes specified in the same function call
+/// will be ignored. Triton will update the specified attributes if 'nullptr'
+/// (success) is returned.
+///
+/// \param backend The backend.
+/// \param backend_attributes Returns the backend attributes.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_GetBackendAttribute(
+    TRITONBACKEND_Backend* backend,
+    TRITONBACKEND_BackendAttribute* backend_attributes);
+
+/// TRITONBACKEND_BackendAttribute
+///
+/// API to modify attributes associated with a backend.
+///
+
+/// Add the preferred instance group of the backend. This function
+/// can be called multiple times to cover different instance group kinds that
+/// the backend supports, with the priority order such that the first call
+/// describes the most preferred group. In the case where instance groups are
+/// not explicitly provided, Triton will use this attribute to create a model
+/// deployment that aligns more closely with the backend's preference.
+///
+/// \param backend_attributes The backend attributes object.
+/// \param kind The kind of the instance group.
+/// \param count The number of instances per device. The Triton default will be
+/// used if 0 is provided.
+/// \param device_ids The devices where instances should be available. The
+/// Triton default will be used if 'nullptr' is provided.
+/// \param id_count The number of devices in 'device_ids'.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_BackendAttributeAddPreferredInstanceGroup(
+    TRITONBACKEND_BackendAttribute* backend_attributes,
+    const TRITONSERVER_InstanceGroupKind kind, const uint64_t count,
+    const uint64_t* device_ids, const uint64_t id_count);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/3rdparty/core-r22.12/include/triton/core/tritonrepoagent.h b/3rdparty/core-r22.12/include/triton/core/tritonrepoagent.h
new file mode 100644
index 0000000000000000000000000000000000000000..078ec6219c8ac05dfec5baa44312003a82bbcdd6
--- /dev/null
+++ b/3rdparty/core-r22.12/include/triton/core/tritonrepoagent.h
@@ -0,0 +1,417 @@
+// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//   * Redistributions of source code must retain the above copyright
+//     notice, this list of conditions and the following disclaimer.
+//   * Redistributions in binary form must reproduce the above copyright
+//     notice, this list of conditions and the following disclaimer in the
+//     documentation and/or other materials provided with the distribution.
+//   * Neither the name of NVIDIA CORPORATION nor the names of its
+//     contributors may be used to endorse or promote products derived
+//     from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+#include "triton/core/tritonserver.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef _COMPILING_TRITONREPOAGENT
+#if defined(_MSC_VER)
+#define TRITONREPOAGENT_DECLSPEC __declspec(dllexport)
+#define TRITONREPOAGENT_ISPEC __declspec(dllimport)
+#elif defined(__GNUC__)
+#define TRITONREPOAGENT_DECLSPEC __attribute__((__visibility__("default")))
+#define TRITONREPOAGENT_ISPEC
+#else
+#define TRITONREPOAGENT_DECLSPEC
+#define TRITONREPOAGENT_ISPEC
+#endif
+#else
+#if defined(_MSC_VER)
+#define TRITONREPOAGENT_DECLSPEC __declspec(dllimport)
+#define TRITONREPOAGENT_ISPEC __declspec(dllexport)
+#else
+#define TRITONREPOAGENT_DECLSPEC
+#define TRITONREPOAGENT_ISPEC
+#endif
+#endif
+
+struct TRITONREPOAGENT_Agent;
+struct TRITONREPOAGENT_AgentModel;
+
+///
+/// TRITONREPOAGENT API Version
+///
+/// The TRITONREPOAGENT API is versioned with major and minor version
+/// numbers. Any change to the API that does not impact backwards
+/// compatibility (for example, adding a non-required function)
+/// increases the minor version number. Any change that breaks
+/// backwards compatibility (for example, deleting or changing the
+/// behavior of a function) increases the major version number.
A +/// repository agent should check that the API version used to compile +/// the agent is compatible with the API version of the Triton server +/// that it is running in. This is typically done by code similar to +/// the following which makes sure that the major versions are equal +/// and that the minor version of Triton is >= the minor version used +/// to build the agent. +/// +/// uint32_t api_version_major, api_version_minor; +/// TRITONREPOAGENT_ApiVersion(&api_version_major, &api_version_minor); +/// if ((api_version_major != TRITONREPOAGENT_API_VERSION_MAJOR) || +/// (api_version_minor < TRITONREPOAGENT_API_VERSION_MINOR)) { +/// return TRITONSERVER_ErrorNew( +/// TRITONSERVER_ERROR_UNSUPPORTED, +/// "triton repository agent API version does not support this agent"); +/// } +/// +#define TRITONREPOAGENT_API_VERSION_MAJOR 0 +#define TRITONREPOAGENT_API_VERSION_MINOR 1 + +/// Get the TRITONREPOAGENT API version supported by Triton. This +/// value can be compared against the +/// TRITONREPOAGENT_API_VERSION_MAJOR and +/// TRITONREPOAGENT_API_VERSION_MINOR used to build the agent to +/// ensure that Triton is compatible with the agent. +/// +/// \param major Returns the TRITONREPOAGENT API major version supported +/// by Triton. +/// \param minor Returns the TRITONREPOAGENT API minor version supported +/// by Triton. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error* TRITONREPOAGENT_ApiVersion( + uint32_t* major, uint32_t* minor); + +/// TRITONREPOAGENT_ArtifactType +/// +/// The ways that the files that make up a model's repository content +/// are communicated between Triton and the agent. +/// +/// TRITONREPOAGENT_ARTIFACT_FILESYSTEM: The model artifacts are +/// communicated to and from the repository agent via a locally +/// accessible filesystem. The agent can access these files using +/// an appropriate filesystem API. +/// +/// TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM: The model artifacts are +/// communicated to and from the repository agent via a remote filesystem. +/// The remote filesystem path follows the same convention as is used for +/// repository paths, for example, "s3://" prefix indicates an S3 path. +/// +typedef enum TRITONREPOAGENT_artifacttype_enum { + TRITONREPOAGENT_ARTIFACT_FILESYSTEM, + TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM +} TRITONREPOAGENT_ArtifactType; + +/// TRITONREPOAGENT_ActionType +/// +/// Types of repository actions that can be handled by an agent. +/// The lifecycle of a TRITONREPOAGENT_AgentModel begins with a call to +/// TRITONREPOAGENT_ModelInitialize and ends with a call to +/// TRITONREPOAGENT_ModelFinalize. Between those calls the current lifecycle +/// state of the model is communicated by calls to TRITONREPOAGENT_ModelAction. +/// Possible lifecycles are: +/// +/// LOAD -> LOAD_COMPLETE -> UNLOAD -> UNLOAD_COMPLETE +/// LOAD -> LOAD_FAIL +/// +/// TRITONREPOAGENT_ACTION_LOAD: A model is being loaded. +/// +/// TRITONREPOAGENT_ACTION_LOAD_COMPLETE: The model load completed +/// successfully and the model is now loaded. +/// +/// TRITONREPOAGENT_ACTION_LOAD_FAIL: The model load did not complete +/// successfully. The model is not loaded. +/// +/// TRITONREPOAGENT_ACTION_UNLOAD: The model is being unloaded. +/// +/// TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: The model unload is complete. 
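+///
+/// For illustration only, a minimal sketch of how an agent's
+/// TRITONREPOAGENT_ModelAction implementation (declared later in this
+/// header) might dispatch on these action types; 'DoLoadWork' is a
+/// hypothetical helper, not part of this API:
+///
+///   TRITONSERVER_Error*
+///   TRITONREPOAGENT_ModelAction(
+///       TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
+///       const TRITONREPOAGENT_ActionType action_type)
+///   {
+///     switch (action_type) {
+///       case TRITONREPOAGENT_ACTION_LOAD:
+///         return DoLoadWork(agent, model);
+///       default:
+///         return nullptr;  // nothing to do for the other actions
+///     }
+///   }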
+///
+typedef enum TRITONREPOAGENT_actiontype_enum {
+  TRITONREPOAGENT_ACTION_LOAD,
+  TRITONREPOAGENT_ACTION_LOAD_COMPLETE,
+  TRITONREPOAGENT_ACTION_LOAD_FAIL,
+  TRITONREPOAGENT_ACTION_UNLOAD,
+  TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE
+} TRITONREPOAGENT_ActionType;
+
+/// Get the location of the files that make up the model. The
+/// 'location' communicated depends on how the model is being
+/// communicated to the agent as indicated by 'artifact_type'.
+///
+///   TRITONREPOAGENT_ARTIFACT_FILESYSTEM: The model artifacts are
+///     made available to the agent via the local
+///     filesystem. 'location' returns the full path to the directory
+///     in the model repository that contains the model's
+///     artifacts. The returned location string is owned by Triton,
+///     not the caller, and so should not be modified or freed. The
+///     contents of the directory are owned by Triton, not the agent,
+///     and so the agent should not delete or modify the contents. Use
+///     TRITONREPOAGENT_ModelRepositoryLocationAcquire to get a location
+///     that can be used to modify the model repository contents.
+///
+///   TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM: The model artifacts are
+///     made available to the agent via a remote filesystem.
+///     'location' returns the full path to the remote directory that contains
+///     the model's artifacts. The returned location string is owned by Triton,
+///     not the caller, and so should not be modified or freed. The contents of
+///     the remote directory are owned by Triton, not the agent,
+///     and so the agent should not delete or modify the contents.
+///     Use TRITONREPOAGENT_ModelRepositoryLocationAcquire to get a location
+///     that can be used to write updated model repository contents.
+///
+/// \param agent The agent.
+/// \param model The model.
+/// \param artifact_type Returns the artifact type for the location.
+/// \param path Returns the location.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error*
+TRITONREPOAGENT_ModelRepositoryLocation(
+    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
+    TRITONREPOAGENT_ArtifactType* artifact_type, const char** location);
+
+/// Acquire a location where the agent can produce a new version of
+/// the model repository files. This is a convenience method to create
+/// a temporary directory for the agent. The agent is responsible for
+/// calling TRITONREPOAGENT_ModelRepositoryLocationRelease in
+/// TRITONREPOAGENT_ModelFinalize to release the location. Initially the
+/// acquired location is empty. The 'location' communicated depends on
+/// the requested 'artifact_type'.
+///
+///   TRITONREPOAGENT_ARTIFACT_FILESYSTEM: The location is a directory
+///     on the local filesystem. 'location' returns the full path to
+///     an empty directory that the agent should populate with the
+///     model's artifacts. The returned location string is owned by
+///     Triton, not the agent, and so should not be modified or freed.
+///
+/// \param agent The agent.
+/// \param model The model.
+/// \param artifact_type The artifact type for the location.
+/// \param path Returns the location.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error*
+TRITONREPOAGENT_ModelRepositoryLocationAcquire(
+    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
+    const TRITONREPOAGENT_ArtifactType artifact_type, const char** location);
+
+/// Discard and release ownership of a previously acquired location
+/// and its contents. The agent must not access or modify the location
+/// or its contents after this call.
+///
+/// \param agent The agent.
+/// \param model The model.
+/// \param path The location to release.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error*
+TRITONREPOAGENT_ModelRepositoryLocationRelease(
+    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
+    const char* location);
+
+/// Inform Triton that the specified repository location should be used for
+/// the model in place of the original model repository. This method can only be
+/// called when TRITONREPOAGENT_ModelAction is invoked with
+/// TRITONREPOAGENT_ACTION_LOAD. The 'location'
+/// communicated depends on how the repository is being
+/// communicated to Triton as indicated by 'artifact_type'.
+///
+///   TRITONREPOAGENT_ARTIFACT_FILESYSTEM: The model artifacts are
+///     made available to Triton via the local filesystem. 'location' returns
+///     the full path to the directory. Ownership of the contents of the
+///     returned directory is transferred to Triton and the agent should not
+///     modify or free the contents until TRITONREPOAGENT_ModelFinalize.
+///     The local filesystem directory can be created using
+///     TRITONREPOAGENT_ModelRepositoryLocationAcquire or the agent can use
+///     its own local filesystem API.
+///
+///   TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM: The model artifacts are
+///     made available to Triton via a remote filesystem. 'location' returns
+///     the full path to the remote filesystem directory. Ownership of the
+///     contents of the returned directory is transferred to Triton and
+///     the agent should not modify or free the contents until
+///     TRITONREPOAGENT_ModelFinalize.
+///
+/// \param agent The agent.
+/// \param model The model.
+/// \param artifact_type The artifact type for the location.
+/// \param path The location.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error*
+TRITONREPOAGENT_ModelRepositoryUpdate(
+    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
+    const TRITONREPOAGENT_ArtifactType artifact_type, const char* location);
+
+/// Get the number of agent parameters defined for a model.
+///
+/// \param agent The agent.
+/// \param model The model.
+/// \param count Returns the number of agent parameters.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error*
+TRITONREPOAGENT_ModelParameterCount(
+    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
+    uint32_t* count);
+
+/// Get a parameter name and value. The caller does not own the
+/// returned strings and must not modify or delete them.
+///
+/// \param agent The agent.
+/// \param model The model.
+/// \param index The index of the parameter. Must be 0 <= index <
+/// count, where count is the value returned by
+/// TRITONREPOAGENT_ModelParameterCount.
+/// \param parameter_name Returns the name of the parameter.
+/// \param parameter_value Returns the value of the parameter.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error* TRITONREPOAGENT_ModelParameter(
+    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
+    const uint32_t index, const char** parameter_name,
+    const char** parameter_value);
+
+/// Get the model configuration.
The caller takes ownership of the +/// message object and must call TRITONSERVER_MessageDelete to release +/// the object. If the model repository does not contain a +/// config.pbtxt file then 'model_config' is returned as nullptr. +/// +/// \param agent The agent. +/// \param model The model. +/// \param config_version The model configuration will be returned in +/// a format matching this version. If the configuration cannot be +/// represented in the requested version's format then an error will +/// be returned. Currently only version 1 is supported. +/// \param model_config Returns the model configuration as a message. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error* TRITONREPOAGENT_ModelConfig( + TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const uint32_t config_version, TRITONSERVER_Message** model_config); + +/// Get the user-specified state associated with the model. +/// +/// \param model The agent model. +/// \param state Returns the user state, or nullptr if no user state. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error* TRITONREPOAGENT_ModelState( + TRITONREPOAGENT_AgentModel* model, void** state); + +/// Set the user-specified state associated with the model. +/// +/// \param model The agent model. +/// \param state The user state, or nullptr if no user state. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error* TRITONREPOAGENT_ModelSetState( + TRITONREPOAGENT_AgentModel* model, void* state); + +/// Get the user-specified state associated with the agent. +/// +/// \param agent The agent. +/// \param state Returns the user state, or nullptr if no user state. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error* TRITONREPOAGENT_State( + TRITONREPOAGENT_Agent* agent, void** state); + +/// Set the user-specified state associated with the agent. +/// +/// \param agent The agent. +/// \param state The user state, or nullptr if no user state. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONREPOAGENT_DECLSPEC TRITONSERVER_Error* TRITONREPOAGENT_SetState( + TRITONREPOAGENT_Agent* agent, void* state); + +/// +/// The following functions can be implemented by an agent. Functions +/// indicated as required must be implemented or the agent will fail +/// to load. +/// + +/// Initialize an agent. This function is optional. This function is +/// called once when an agent is loaded to allow the agent to +/// initialize any state associated with the agent. An agent has a +/// single state that is shared across all invocations of the agent. +/// +/// \param agent The agent. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONREPOAGENT_ISPEC TRITONSERVER_Error* TRITONREPOAGENT_Initialize( + TRITONREPOAGENT_Agent* agent); + +/// Finalize for an agent. This function is optional. This function is +/// called once, just before the agent is unloaded. All state +/// associated with the agent should be freed and any threads created +/// for the agent should be exited/joined before returning from this +/// function. +/// +/// \param agent The agent. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONREPOAGENT_ISPEC TRITONSERVER_Error* TRITONREPOAGENT_Finalize( + TRITONREPOAGENT_Agent* agent); + +/// Initialize a model associated with an agent. This function is optional. 
+/// This function is called once when an agent model's lifecycle begins to allow
+/// the agent model to initialize any state associated with it. An agent model
+/// has a single state that is shared across the entire lifecycle of the agent
+/// model.
+///
+/// \param agent The agent to be associated with the model.
+/// \param model The model.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONREPOAGENT_ISPEC TRITONSERVER_Error* TRITONREPOAGENT_ModelInitialize(
+    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model);
+
+/// Finalize for a model. This function is optional. This function is
+/// called once, just before the end of the agent model's lifecycle. All state
+/// associated with the agent model should be freed and any threads created
+/// for the agent model should be exited/joined before returning from this
+/// function. If the model acquired a model location using
+/// TRITONREPOAGENT_ModelRepositoryLocationAcquire, it must call
+/// TRITONREPOAGENT_ModelRepositoryLocationRelease to release that location.
+///
+/// \param agent The agent associated with the model.
+/// \param model The model.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONREPOAGENT_ISPEC TRITONSERVER_Error* TRITONREPOAGENT_ModelFinalize(
+    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model);
+
+/// Handle an action for a specified model. This function is
+/// required. Triton will not perform multiple simultaneous calls to
+/// this function for a given agent and model; however, there may be
+/// simultaneous calls for the agent for different models.
+///
+/// If the agent does not handle the action the agent should
+/// immediately return success (nullptr).
+///
+/// Any modification to the model's repository must be made when 'action_type'
+/// is TRITONREPOAGENT_ACTION_LOAD.
+/// To modify the model's repository the agent must either acquire a mutable
+/// location via TRITONREPOAGENT_ModelRepositoryLocationAcquire
+/// or use its own managed location, report the location to Triton via
+/// TRITONREPOAGENT_ModelRepositoryUpdate, and then return
+/// success (nullptr). If the agent does not need to make any changes
+/// to the model repository it should not call
+/// TRITONREPOAGENT_ModelRepositoryUpdate and should simply return success.
+/// To indicate that a model load should fail return a non-success status.
+///
+/// \param agent The agent.
+/// \param model The model that is the target of the action.
+/// \param action_type The type of action the agent should handle for the model.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONREPOAGENT_ISPEC TRITONSERVER_Error* TRITONREPOAGENT_ModelAction(
+    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
+    const TRITONREPOAGENT_ActionType action_type);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/3rdparty/core-r22.12/include/triton/core/tritonserver.h b/3rdparty/core-r22.12/include/triton/core/tritonserver.h
new file mode 100644
index 0000000000000000000000000000000000000000..6edd5f1809116166215e4b1702b12dfba7f19de4
--- /dev/null
+++ b/3rdparty/core-r22.12/include/triton/core/tritonserver.h
@@ -0,0 +1,2360 @@
+// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//  * Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//  * Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#pragma once
+
+/// \file
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef _COMPILING_TRITONSERVER
+#if defined(_MSC_VER)
+#define TRITONSERVER_DECLSPEC __declspec(dllexport)
+#elif defined(__GNUC__)
+#define TRITONSERVER_DECLSPEC __attribute__((__visibility__("default")))
+#else
+#define TRITONSERVER_DECLSPEC
+#endif
+#else
+#if defined(_MSC_VER)
+#define TRITONSERVER_DECLSPEC __declspec(dllimport)
+#else
+#define TRITONSERVER_DECLSPEC
+#endif
+#endif
+
+struct TRITONSERVER_BufferAttributes;
+struct TRITONSERVER_Error;
+struct TRITONSERVER_InferenceRequest;
+struct TRITONSERVER_InferenceResponse;
+struct TRITONSERVER_InferenceTrace;
+struct TRITONSERVER_Message;
+struct TRITONSERVER_Metrics;
+struct TRITONSERVER_Parameter;
+struct TRITONSERVER_ResponseAllocator;
+struct TRITONSERVER_Server;
+struct TRITONSERVER_ServerOptions;
+struct TRITONSERVER_Metric;
+struct TRITONSERVER_MetricFamily;
+
+///
+/// TRITONSERVER API Version
+///
+/// The TRITONSERVER API is versioned with major and minor version
+/// numbers. Any change to the API that does not impact backwards
+/// compatibility (for example, adding a non-required function)
+/// increases the minor version number. Any change that breaks
+/// backwards compatibility (for example, deleting or changing the
+/// behavior of a function) increases the major version number. A
+/// client should check that the API version used to compile the
+/// client is compatible with the API version of the Triton shared
+/// library that it is linking against. This is typically done by code
+/// similar to the following which makes sure that the major versions
+/// are equal and that the minor version of the Triton shared library
+/// is >= the minor version used to build the client.
+///
+///   uint32_t api_version_major, api_version_minor;
+///   TRITONSERVER_ApiVersion(&api_version_major, &api_version_minor);
+///   if ((api_version_major != TRITONSERVER_API_VERSION_MAJOR) ||
+///       (api_version_minor < TRITONSERVER_API_VERSION_MINOR)) {
+///     return TRITONSERVER_ErrorNew(
+///         TRITONSERVER_ERROR_UNSUPPORTED,
+///         "triton server API version does not support this client");
+///   }
+///
+#define TRITONSERVER_API_VERSION_MAJOR 1
+#define TRITONSERVER_API_VERSION_MINOR 17
+
+/// Get the TRITONSERVER API version supported by the Triton shared
+/// library. This value can be compared against the
+/// TRITONSERVER_API_VERSION_MAJOR and TRITONSERVER_API_VERSION_MINOR
+/// used to build the client to ensure that the Triton shared library is
+/// compatible with the client.
+///
+/// \param major Returns the TRITONSERVER API major version supported
+/// by Triton.
+/// \param minor Returns the TRITONSERVER API minor version supported
+/// by Triton.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ApiVersion(
+    uint32_t* major, uint32_t* minor);
+
+/// TRITONSERVER_DataType
+///
+/// Tensor data types recognized by TRITONSERVER.
+///
+typedef enum TRITONSERVER_datatype_enum {
+  TRITONSERVER_TYPE_INVALID,
+  TRITONSERVER_TYPE_BOOL,
+  TRITONSERVER_TYPE_UINT8,
+  TRITONSERVER_TYPE_UINT16,
+  TRITONSERVER_TYPE_UINT32,
+  TRITONSERVER_TYPE_UINT64,
+  TRITONSERVER_TYPE_INT8,
+  TRITONSERVER_TYPE_INT16,
+  TRITONSERVER_TYPE_INT32,
+  TRITONSERVER_TYPE_INT64,
+  TRITONSERVER_TYPE_FP16,
+  TRITONSERVER_TYPE_FP32,
+  TRITONSERVER_TYPE_FP64,
+  TRITONSERVER_TYPE_BYTES,
+  TRITONSERVER_TYPE_BF16
+} TRITONSERVER_DataType;
+
+/// Get the string representation of a data type. The returned string
+/// is not owned by the caller and so should not be modified or freed.
+///
+/// \param datatype The data type.
+/// \return The string representation of the data type.
+TRITONSERVER_DECLSPEC const char* TRITONSERVER_DataTypeString(
+    TRITONSERVER_DataType datatype);
+
+/// Get the Triton datatype corresponding to a string representation
+/// of a datatype.
+///
+/// \param dtype The datatype string representation.
+/// \return The Triton data type or TRITONSERVER_TYPE_INVALID if the
+/// string does not represent a data type.
+TRITONSERVER_DECLSPEC TRITONSERVER_DataType
+TRITONSERVER_StringToDataType(const char* dtype);
+
+/// Get the size of a Triton datatype in bytes. Zero is returned for
+/// TRITONSERVER_TYPE_BYTES because it has a variable size. Zero is
+/// returned for TRITONSERVER_TYPE_INVALID.
+///
+/// \param datatype The datatype.
+/// \return The size of the datatype.
+TRITONSERVER_DECLSPEC uint32_t
+TRITONSERVER_DataTypeByteSize(TRITONSERVER_DataType datatype);
+
+/// TRITONSERVER_MemoryType
+///
+/// Types of memory recognized by TRITONSERVER.
+///
+typedef enum TRITONSERVER_memorytype_enum {
+  TRITONSERVER_MEMORY_CPU,
+  TRITONSERVER_MEMORY_CPU_PINNED,
+  TRITONSERVER_MEMORY_GPU
+} TRITONSERVER_MemoryType;
+
+/// Get the string representation of a memory type. The returned
+/// string is not owned by the caller and so should not be modified or
+/// freed.
+///
+/// \param memtype The memory type.
+/// \return The string representation of the memory type.
+TRITONSERVER_DECLSPEC const char* TRITONSERVER_MemoryTypeString(
+    TRITONSERVER_MemoryType memtype);
+
+/// TRITONSERVER_ParameterType
+///
+/// Types of parameters recognized by TRITONSERVER.
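+///
+/// As an informal sketch (using the parameter functions declared further
+/// below; the parameter names and values are examples only), an integer
+/// parameter and a bytes parameter could be created and released like this:
+///
+///   int64_t batch = 8;  // assuming a 64-bit integer payload
+///   TRITONSERVER_Parameter* p1 = TRITONSERVER_ParameterNew(
+///       "max_batch", TRITONSERVER_PARAMETER_INT, &batch);
+///   const char blob[] = {0x01, 0x02, 0x03};
+///   TRITONSERVER_Parameter* p2 = TRITONSERVER_ParameterBytesNew(
+///       "raw_header", blob, sizeof(blob));
+///   ...  // 'blob' must stay valid until p2 is deleted (shallow copy)
+///   TRITONSERVER_ParameterDelete(p1);
+///   TRITONSERVER_ParameterDelete(p2);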
+/// +typedef enum TRITONSERVER_parametertype_enum { + TRITONSERVER_PARAMETER_STRING, + TRITONSERVER_PARAMETER_INT, + TRITONSERVER_PARAMETER_BOOL, + TRITONSERVER_PARAMETER_BYTES +} TRITONSERVER_ParameterType; + +/// Get the string representation of a parameter type. The returned +/// string is not owned by the caller and so should not be modified or +/// freed. +/// +/// \param paramtype The parameter type. +/// \return The string representation of the parameter type. +TRITONSERVER_DECLSPEC const char* TRITONSERVER_ParameterTypeString( + TRITONSERVER_ParameterType paramtype); + +/// Create a new parameter object. The caller takes ownership of the +/// TRITONSERVER_Parameter object and must call TRITONSERVER_ParameterDelete to +/// release the object. The object will maintain its own copy of the 'value' +/// +/// \param name The parameter name. +/// \param type The parameter type. +/// \param value The pointer to the value. +/// \return A new TRITONSERVER_Parameter object. 'nullptr' will be returned if +/// 'type' is 'TRITONSERVER_PARAMETER_BYTES'. The caller should use +/// TRITONSERVER_ParameterBytesNew to create parameter with bytes type. +TRITONSERVER_DECLSPEC TRITONSERVER_Parameter* TRITONSERVER_ParameterNew( + const char* name, const TRITONSERVER_ParameterType type, const void* value); + +/// Create a new parameter object with type TRITONSERVER_PARAMETER_BYTES. +/// The caller takes ownership of the TRITONSERVER_Parameter object and must +/// call TRITONSERVER_ParameterDelete to release the object. The object only +/// maintains a shallow copy of the 'byte_ptr' so the data content must be +/// valid until the parameter object is deleted. +/// +/// \param name The parameter name. +/// \param byte_ptr The pointer to the data content. +/// \param size The size of the data content. +/// \return A new TRITONSERVER_Error object. +TRITONSERVER_DECLSPEC TRITONSERVER_Parameter* TRITONSERVER_ParameterBytesNew( + const char* name, const void* byte_ptr, const uint64_t size); + +/// Delete an parameter object. +/// +/// \param parameter The parameter object. +TRITONSERVER_DECLSPEC void TRITONSERVER_ParameterDelete( + TRITONSERVER_Parameter* parameter); + +/// TRITONSERVER_InstanceGroupKind +/// +/// Kinds of instance groups recognized by TRITONSERVER. +/// +typedef enum TRITONSERVER_instancegroupkind_enum { + TRITONSERVER_INSTANCEGROUPKIND_AUTO, + TRITONSERVER_INSTANCEGROUPKIND_CPU, + TRITONSERVER_INSTANCEGROUPKIND_GPU, + TRITONSERVER_INSTANCEGROUPKIND_MODEL +} TRITONSERVER_InstanceGroupKind; + +/// Get the string representation of an instance-group kind. The +/// returned string is not owned by the caller and so should not be +/// modified or freed. +/// +/// \param kind The instance-group kind. +/// \return The string representation of the kind. +TRITONSERVER_DECLSPEC const char* TRITONSERVER_InstanceGroupKindString( + TRITONSERVER_InstanceGroupKind kind); + +/// TRITONSERVER_Logging +/// +/// Types/levels of logging. +/// +typedef enum TRITONSERVER_loglevel_enum { + TRITONSERVER_LOG_INFO, + TRITONSERVER_LOG_WARN, + TRITONSERVER_LOG_ERROR, + TRITONSERVER_LOG_VERBOSE +} TRITONSERVER_LogLevel; + +/// +/// Format of logging. +/// +/// TRITONSERVER_LOG_DEFAULT: the log severity (L) and timestamp will be +/// logged as "LMMDD hh:mm:ss.ssssss". +/// +/// TRITONSERVER_LOG_ISO8601: the log format will be "YYYY-MM-DDThh:mm:ssZ L". +/// +typedef enum TRITONSERVER_logformat_enum { + TRITONSERVER_LOG_DEFAULT, + TRITONSERVER_LOG_ISO8601 +} TRITONSERVER_LogFormat; + +/// Is a log level enabled? 
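+///
+/// For example (a minimal usage sketch, not taken from this header), a
+/// caller might guard an expensive message with this check and emit it via
+/// TRITONSERVER_LogMessage, declared below:
+///
+///   if (TRITONSERVER_LogIsEnabled(TRITONSERVER_LOG_VERBOSE)) {
+///     TRITONSERVER_Error* err = TRITONSERVER_LogMessage(
+///         TRITONSERVER_LOG_VERBOSE, __FILE__, __LINE__, "verbose details");
+///     if (err != nullptr) {
+///       TRITONSERVER_ErrorDelete(err);  // ignore logging failures here
+///     }
+///   }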
+/// +/// \param level The log level. +/// \return True if the log level is enabled, false if not enabled. +TRITONSERVER_DECLSPEC bool TRITONSERVER_LogIsEnabled( + TRITONSERVER_LogLevel level); + +/// Log a message at a given log level if that level is enabled. +/// +/// \param level The log level. +/// \param filename The file name of the location of the log message. +/// \param line The line number of the log message. +/// \param msg The log message. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_LogMessage( + TRITONSERVER_LogLevel level, const char* filename, const int line, + const char* msg); + +/// TRITONSERVER_Error +/// +/// Errors are reported by a TRITONSERVER_Error object. A NULL +/// TRITONSERVER_Error indicates no error, a non-NULL TRITONSERVER_Error +/// indicates error and the code and message for the error can be +/// retrieved from the object. +/// +/// The caller takes ownership of a TRITONSERVER_Error object returned by +/// the API and must call TRITONSERVER_ErrorDelete to release the object. +/// + +/// The TRITONSERVER_Error error codes +typedef enum TRITONSERVER_errorcode_enum { + TRITONSERVER_ERROR_UNKNOWN, + TRITONSERVER_ERROR_INTERNAL, + TRITONSERVER_ERROR_NOT_FOUND, + TRITONSERVER_ERROR_INVALID_ARG, + TRITONSERVER_ERROR_UNAVAILABLE, + TRITONSERVER_ERROR_UNSUPPORTED, + TRITONSERVER_ERROR_ALREADY_EXISTS +} TRITONSERVER_Error_Code; + +/// Create a new error object. The caller takes ownership of the +/// TRITONSERVER_Error object and must call TRITONSERVER_ErrorDelete to +/// release the object. +/// +/// \param code The error code. +/// \param msg The error message. +/// \return A new TRITONSERVER_Error object. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ErrorNew( + TRITONSERVER_Error_Code code, const char* msg); + +/// Delete an error object. +/// +/// \param error The error object. +TRITONSERVER_DECLSPEC void TRITONSERVER_ErrorDelete(TRITONSERVER_Error* error); + +/// Get the error code. +/// +/// \param error The error object. +/// \return The error code. +TRITONSERVER_DECLSPEC TRITONSERVER_Error_Code +TRITONSERVER_ErrorCode(TRITONSERVER_Error* error); + +/// Get the string representation of an error code. The returned +/// string is not owned by the caller and so should not be modified or +/// freed. The lifetime of the returned string extends only as long as +/// 'error' and must not be accessed once 'error' is deleted. +/// +/// \param error The error object. +/// \return The string representation of the error code. +TRITONSERVER_DECLSPEC const char* TRITONSERVER_ErrorCodeString( + TRITONSERVER_Error* error); + +/// Get the error message. The returned string is not owned by the +/// caller and so should not be modified or freed. The lifetime of the +/// returned string extends only as long as 'error' and must not be +/// accessed once 'error' is deleted. +/// +/// \param error The error object. +/// \return The error message. +TRITONSERVER_DECLSPEC const char* TRITONSERVER_ErrorMessage( + TRITONSERVER_Error* error); + +/// TRITONSERVER_ResponseAllocator +/// +/// Object representing a memory allocator for output tensors in an +/// inference response. +/// + +/// Type for allocation function that allocates a buffer to hold an +/// output tensor. +/// +/// \param allocator The allocator that is provided in the call to +/// TRITONSERVER_InferenceRequestSetResponseCallback. +/// \param tensor_name The name of the output tensor to allocate for. 
+/// \param byte_size The size of the buffer to allocate. +/// \param memory_type The type of memory that the caller prefers for +/// the buffer allocation. +/// \param memory_type_id The ID of the memory that the caller prefers +/// for the buffer allocation. +/// \param userp The user data pointer that is provided as +/// 'response_allocator_userp' in the call to +/// TRITONSERVER_InferenceRequestSetResponseCallback. +/// \param buffer Returns a pointer to the allocated memory. +/// \param buffer_userp Returns a user-specified value to associate +/// with the buffer, or nullptr if no user-specified value should be +/// associated with the buffer. This value will be provided in the +/// call to TRITONSERVER_ResponseAllocatorReleaseFn_t when the buffer +/// is released and will also be returned by +/// TRITONSERVER_InferenceResponseOutput. +/// \param actual_memory_type Returns the type of memory where the +/// allocation resides. May be different than the type of memory +/// requested by 'memory_type'. +/// \param actual_memory_type_id Returns the ID of the memory where +/// the allocation resides. May be different than the ID of the memory +/// requested by 'memory_type_id'. +/// \return a TRITONSERVER_Error object if a failure occurs while +/// attempting an allocation. If an error is returned all other return +/// values will be ignored. +typedef TRITONSERVER_Error* (*TRITONSERVER_ResponseAllocatorAllocFn_t)( + TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id, void* userp, void** buffer, void** buffer_userp, + TRITONSERVER_MemoryType* actual_memory_type, + int64_t* actual_memory_type_id); + +/// Type for allocation function that allocates a buffer to hold an +/// output tensor with buffer attributes. The callback function must fill in the +/// appropriate buffer attributes information related to this buffer. If set, +/// this function is always called after TRITONSERVER_ResponseAllocatorAllocFn_t +/// function. +/// +/// \param allocator The allocator that is provided in the call to +/// TRITONSERVER_InferenceRequestSetResponseCallback. +/// \param tensor_name The name of the output tensor to allocate for. +/// \param buffer_attributes The buffer attributes associated with the buffer. +/// \param userp The user data pointer that is provided as +/// 'response_allocator_userp' in the call to +/// TRITONSERVER_InferenceRequestSetResponseCallback. +/// \param buffer_userp Returns a user-specified value to associate +/// with the buffer, or nullptr if no user-specified value should be +/// associated with the buffer. This value will be provided in the +/// call to TRITONSERVER_ResponseAllocatorReleaseFn_t when the buffer +/// is released and will also be returned by +/// TRITONSERVER_InferenceResponseOutput. +/// \return a TRITONSERVER_Error object if a failure occurs while +/// attempting an allocation. If an error is returned all other return +/// values will be ignored. +typedef TRITONSERVER_Error* ( + *TRITONSERVER_ResponseAllocatorBufferAttributesFn_t)( + TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + TRITONSERVER_BufferAttributes* buffer_attributes, void* userp, + void* buffer_userp); + +/// Type for function that is called to query the allocator's preferred memory +/// type and memory type ID. 
As much as possible, the allocator should attempt
+/// to return the same memory_type and memory_type_id values that will be
+/// returned by the subsequent call to TRITONSERVER_ResponseAllocatorAllocFn_t.
+/// But the allocator is not required to do so.
+///
+/// \param allocator The allocator that is provided in the call to
+/// TRITONSERVER_InferenceRequestSetResponseCallback.
+/// \param userp The user data pointer that is provided as
+/// 'response_allocator_userp' in the call to
+/// TRITONSERVER_InferenceRequestSetResponseCallback.
+/// \param tensor_name The name of the output tensor. This is optional
+/// and it should be set to nullptr to indicate that the tensor name has
+/// not been determined.
+/// \param byte_size The expected size of the buffer. This is optional
+/// and it should be set to nullptr to indicate that the byte size has
+/// not been determined.
+/// \param memory_type Acts as both input and output. On input gives
+/// the memory type preferred by the caller. Returns the memory type preferred
+/// by the allocator, taking into account the caller's preferred type.
+/// \param memory_type_id Acts as both input and output. On input gives
+/// the memory type ID preferred by the caller. Returns the memory type ID
+/// preferred by the allocator, taking into account the caller's preferred
+/// type ID.
+/// \return a TRITONSERVER_Error object if a failure occurs.
+typedef TRITONSERVER_Error* (*TRITONSERVER_ResponseAllocatorQueryFn_t)(
+    TRITONSERVER_ResponseAllocator* allocator, void* userp,
+    const char* tensor_name, size_t* byte_size,
+    TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id);
+
+/// Type for function that is called when the server no longer holds
+/// any reference to a buffer allocated by
+/// TRITONSERVER_ResponseAllocatorAllocFn_t. In practice this function
+/// is typically called when the response object associated with the
+/// buffer is deleted by TRITONSERVER_InferenceResponseDelete.
+///
+/// \param allocator The allocator that is provided in the call to
+/// TRITONSERVER_InferenceRequestSetResponseCallback.
+/// \param buffer Pointer to the buffer to be freed.
+/// \param buffer_userp The user-specified value associated
+/// with the buffer in TRITONSERVER_ResponseAllocatorAllocFn_t.
+/// \param byte_size The size of the buffer.
+/// \param memory_type The type of memory holding the buffer.
+/// \param memory_type_id The ID of the memory holding the buffer.
+/// \return a TRITONSERVER_Error object if a failure occurs while
+/// attempting the release. If an error is returned Triton will not
+/// attempt to release the buffer again.
+typedef TRITONSERVER_Error* (*TRITONSERVER_ResponseAllocatorReleaseFn_t)(
+    TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp,
+    size_t byte_size, TRITONSERVER_MemoryType memory_type,
+    int64_t memory_type_id);
+
+/// Type for function that is called to indicate that subsequent
+/// allocation requests will refer to a new response.
+///
+/// \param allocator The allocator that is provided in the call to
+/// TRITONSERVER_InferenceRequestSetResponseCallback.
+/// \param userp The user data pointer that is provided as
+/// 'response_allocator_userp' in the call to
+/// TRITONSERVER_InferenceRequestSetResponseCallback.
+/// \return a TRITONSERVER_Error object if a failure occurs.
+typedef TRITONSERVER_Error* (*TRITONSERVER_ResponseAllocatorStartFn_t)(
+    TRITONSERVER_ResponseAllocator* allocator, void* userp);
+
+/// Create a new response allocator object.
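+///
+/// As an illustrative sketch only (CPU-only, minimal error handling, the
+/// function names are hypothetical), a matching pair of allocation and
+/// release callbacks and the corresponding allocator creation could look
+/// like:
+///
+///   TRITONSERVER_Error*
+///   CpuAlloc(
+///       TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name,
+///       size_t byte_size, TRITONSERVER_MemoryType memory_type,
+///       int64_t memory_type_id, void* userp, void** buffer,
+///       void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type,
+///       int64_t* actual_memory_type_id)
+///   {
+///     *buffer = (byte_size == 0) ? nullptr : malloc(byte_size);
+///     *buffer_userp = nullptr;
+///     *actual_memory_type = TRITONSERVER_MEMORY_CPU;  // may differ from request
+///     *actual_memory_type_id = 0;
+///     return nullptr;  // success
+///   }
+///
+///   TRITONSERVER_Error*
+///   CpuRelease(
+///       TRITONSERVER_ResponseAllocator* allocator, void* buffer,
+///       void* buffer_userp, size_t byte_size,
+///       TRITONSERVER_MemoryType memory_type, int64_t memory_type_id)
+///   {
+///     free(buffer);
+///     return nullptr;  // success
+///   }
+///
+///   TRITONSERVER_ResponseAllocator* allocator = nullptr;
+///   TRITONSERVER_ResponseAllocatorNew(
+///       &allocator, CpuAlloc, CpuRelease, nullptr /* start_fn */);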
+/// +/// The response allocator object is used by Triton to allocate +/// buffers to hold the output tensors in inference responses. Most +/// models generate a single response for each inference request +/// (TRITONSERVER_TXN_ONE_TO_ONE). For these models the order of +/// callbacks will be: +/// +/// TRITONSERVER_ServerInferAsync called +/// - start_fn : optional (and typically not required) +/// - alloc_fn : called once for each output tensor in response +/// TRITONSERVER_InferenceResponseDelete called +/// - release_fn: called once for each output tensor in response +/// +/// For models that generate multiple responses for each inference +/// request (TRITONSERVER_TXN_DECOUPLED), the start_fn callback can be +/// used to determine sets of alloc_fn callbacks that belong to the +/// same response: +/// +/// TRITONSERVER_ServerInferAsync called +/// - start_fn +/// - alloc_fn : called once for each output tensor in response +/// - start_fn +/// - alloc_fn : called once for each output tensor in response +/// ... +/// For each response, TRITONSERVER_InferenceResponseDelete called +/// - release_fn: called once for each output tensor in the response +/// +/// In all cases the start_fn, alloc_fn and release_fn callback +/// functions must be thread-safe. Typically making these functions +/// thread-safe does not require explicit locking. The recommended way +/// to implement these functions is to have each inference request +/// provide a 'response_allocator_userp' object that is unique to that +/// request with TRITONSERVER_InferenceRequestSetResponseCallback. The +/// callback functions then operate only on this unique state. Locking +/// is required only when the callback function needs to access state +/// that is shared across inference requests (for example, a common +/// allocation pool). +/// +/// \param allocator Returns the new response allocator object. +/// \param alloc_fn The function to call to allocate buffers for result +/// tensors. +/// \param release_fn The function to call when the server no longer +/// holds a reference to an allocated buffer. +/// \param start_fn The function to call to indicate that the +/// subsequent 'alloc_fn' calls are for a new response. This callback +/// is optional (use nullptr to indicate that it should not be +/// invoked). +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ResponseAllocatorNew( + TRITONSERVER_ResponseAllocator** allocator, + TRITONSERVER_ResponseAllocatorAllocFn_t alloc_fn, + TRITONSERVER_ResponseAllocatorReleaseFn_t release_fn, + TRITONSERVER_ResponseAllocatorStartFn_t start_fn); + +/// Set the buffer attributes function for a response allocator object. +/// The function will be called after alloc_fn to set the buffer attributes +/// associated with the output buffer. +/// +/// The thread-safy requirement for buffer_attributes_fn is the same as other +/// allocator callbacks. +/// +/// \param allocator The response allocator object. +/// \param buffer_attributes_fn The function to call to get the buffer +/// attributes information for an allocated buffer. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ResponseAllocatorSetBufferAttributesFunction( + TRITONSERVER_ResponseAllocator* allocator, + TRITONSERVER_ResponseAllocatorBufferAttributesFn_t buffer_attributes_fn); + +/// Set the query function to a response allocator object. 
Usually the +/// function will be called before alloc_fn to understand what is the +/// allocator's preferred memory type and memory type ID at the current +/// situation to make different execution decision. +/// +/// The thread-safy requirement for query_fn is the same as other allocator +/// callbacks. +/// +/// \param allocator The response allocator object. +/// \param query_fn The function to call to query allocator's preferred memory +/// type and memory type ID. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ResponseAllocatorSetQueryFunction( + TRITONSERVER_ResponseAllocator* allocator, + TRITONSERVER_ResponseAllocatorQueryFn_t query_fn); + +/// Delete a response allocator. +/// +/// \param allocator The response allocator object. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ResponseAllocatorDelete( + TRITONSERVER_ResponseAllocator* allocator); + +/// TRITONSERVER_Message +/// +/// Object representing a Triton Server message. +/// + +/// Create a new message object from serialized JSON string. +/// +/// \param message The message object. +/// \param base The base of the serialized JSON. +/// \param byte_size The size, in bytes, of the serialized message. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_MessageNewFromSerializedJson( + TRITONSERVER_Message** message, const char* base, size_t byte_size); + +/// Delete a message object. +/// +/// \param message The message object. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_MessageDelete( + TRITONSERVER_Message* message); + +/// Get the base and size of the buffer containing the serialized +/// message in JSON format. The buffer is owned by the +/// TRITONSERVER_Message object and should not be modified or freed by +/// the caller. The lifetime of the buffer extends only as long as +/// 'message' and must not be accessed once 'message' is deleted. +/// +/// \param message The message object. +/// \param base Returns the base of the serialized message. +/// \param byte_size Returns the size, in bytes, of the serialized +/// message. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_MessageSerializeToJson( + TRITONSERVER_Message* message, const char** base, size_t* byte_size); + +/// TRITONSERVER_Metrics +/// +/// Object representing metrics. +/// + +/// Metric format types +typedef enum tritonserver_metricformat_enum { + TRITONSERVER_METRIC_PROMETHEUS +} TRITONSERVER_MetricFormat; + +/// Delete a metrics object. +/// +/// \param metrics The metrics object. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_MetricsDelete( + TRITONSERVER_Metrics* metrics); + +/// Get a buffer containing the metrics in the specified format. For +/// each format the buffer contains the following: +/// +/// TRITONSERVER_METRIC_PROMETHEUS: 'base' points to a single multiline +/// string (char*) that gives a text representation of the metrics in +/// prometheus format. 'byte_size' returns the length of the string +/// in bytes. +/// +/// The buffer is owned by the 'metrics' object and should not be +/// modified or freed by the caller. 
The lifetime of the buffer +/// extends only as long as 'metrics' and must not be accessed once +/// 'metrics' is deleted. +/// +/// \param metrics The metrics object. +/// \param format The format to use for the returned metrics. +/// \param base Returns a pointer to the base of the formatted +/// metrics, as described above. +/// \param byte_size Returns the size, in bytes, of the formatted +/// metrics. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_MetricsFormatted( + TRITONSERVER_Metrics* metrics, TRITONSERVER_MetricFormat format, + const char** base, size_t* byte_size); + +/// TRITONSERVER_InferenceTrace +/// +/// Object that represents tracing for an inference request. +/// + +/// Trace levels. The trace level controls the type of trace +/// activities that are reported for an inference request. +/// +/// Trace level values are power-of-2 and can be combined to trace +/// multiple types of activities. For example, use +/// (TRITONSERVER_TRACE_LEVEL_TIMESTAMPS | +/// TRITONSERVER_TRACE_LEVEL_TENSORS) to trace both timestamps and +/// tensors for an inference request. +/// +/// TRITONSERVER_TRACE_LEVEL_MIN and TRITONSERVER_TRACE_LEVEL_MAX are +/// deprecated and should not be used. +typedef enum tritonserver_tracelevel_enum { + /// Tracing disabled. No trace activities are reported. + TRITONSERVER_TRACE_LEVEL_DISABLED = 0, + /// Deprecated. Use TRITONSERVER_TRACE_LEVEL_TIMESTAMPS. + TRITONSERVER_TRACE_LEVEL_MIN = 1, + /// Deprecated. Use TRITONSERVER_TRACE_LEVEL_TIMESTAMPS. + TRITONSERVER_TRACE_LEVEL_MAX = 2, + /// Record timestamps for the inference request. + TRITONSERVER_TRACE_LEVEL_TIMESTAMPS = 0x4, + /// Record input and output tensor values for the inference request. + TRITONSERVER_TRACE_LEVEL_TENSORS = 0x8 +} TRITONSERVER_InferenceTraceLevel; + +/// Get the string representation of a trace level. The returned +/// string is not owned by the caller and so should not be modified or +/// freed. +/// +/// \param level The trace level. +/// \return The string representation of the trace level. +TRITONSERVER_DECLSPEC const char* TRITONSERVER_InferenceTraceLevelString( + TRITONSERVER_InferenceTraceLevel level); + +/// Trace activities +typedef enum tritonserver_traceactivity_enum { + TRITONSERVER_TRACE_REQUEST_START = 0, + TRITONSERVER_TRACE_QUEUE_START = 1, + TRITONSERVER_TRACE_COMPUTE_START = 2, + TRITONSERVER_TRACE_COMPUTE_INPUT_END = 3, + TRITONSERVER_TRACE_COMPUTE_OUTPUT_START = 4, + TRITONSERVER_TRACE_COMPUTE_END = 5, + TRITONSERVER_TRACE_REQUEST_END = 6, + TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT = 7, + TRITONSERVER_TRACE_TENSOR_BACKEND_INPUT = 8, + TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT = 9 +} TRITONSERVER_InferenceTraceActivity; + +/// Get the string representation of a trace activity. The returned +/// string is not owned by the caller and so should not be modified or +/// freed. +/// +/// \param activity The trace activity. +/// \return The string representation of the trace activity. +TRITONSERVER_DECLSPEC const char* TRITONSERVER_InferenceTraceActivityString( + TRITONSERVER_InferenceTraceActivity activity); + +/// Type for trace timeline activity callback function. This callback function +/// is used to report activity occurring for a trace. This function +/// does not take ownership of 'trace' and so any information needed +/// from that object must be copied before returning. The 'userp' data +/// is the same as what is supplied in the call to +/// TRITONSERVER_InferenceTraceNew. 
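+///
+/// A minimal sketch of such a callback (hypothetical, simply printing the
+/// reported activity) might be:
+///
+///   void
+///   MyTraceActivity(
+///       TRITONSERVER_InferenceTrace* trace,
+///       TRITONSERVER_InferenceTraceActivity activity, uint64_t timestamp_ns,
+///       void* userp)
+///   {
+///     uint64_t id = 0;
+///     TRITONSERVER_InferenceTraceId(trace, &id);
+///     printf(
+///         "trace %llu: %s at %llu ns\n", (unsigned long long)id,
+///         TRITONSERVER_InferenceTraceActivityString(activity),
+///         (unsigned long long)timestamp_ns);
+///   }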
+typedef void (*TRITONSERVER_InferenceTraceActivityFn_t)( + TRITONSERVER_InferenceTrace* trace, + TRITONSERVER_InferenceTraceActivity activity, uint64_t timestamp_ns, + void* userp); + +/// Type for trace tensor activity callback function. This callback function +/// is used to report tensor activity occurring for a trace. This function +/// does not take ownership of 'trace' and so any information needed +/// from that object must be copied before returning. The 'userp' data +/// is the same as what is supplied in the call to +/// TRITONSERVER_InferenceTraceTensorNew. +typedef void (*TRITONSERVER_InferenceTraceTensorActivityFn_t)( + TRITONSERVER_InferenceTrace* trace, + TRITONSERVER_InferenceTraceActivity activity, const char* name, + TRITONSERVER_DataType datatype, const void* base, size_t byte_size, + const int64_t* shape, uint64_t dim_count, + TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, void* userp); + +/// Type for trace release callback function. This callback function +/// is called when all activity for the trace has completed. The +/// callback function takes ownership of the +/// TRITONSERVER_InferenceTrace object. The 'userp' data is the same +/// as what is supplied in the call to TRITONSERVER_InferenceTraceNew. +typedef void (*TRITONSERVER_InferenceTraceReleaseFn_t)( + TRITONSERVER_InferenceTrace* trace, void* userp); + +/// Create a new inference trace object. The caller takes ownership of +/// the TRITONSERVER_InferenceTrace object and must call +/// TRITONSERVER_InferenceTraceDelete to release the object. +/// +/// The activity callback function will be called to report activity +/// for 'trace' as well as for any child traces that are spawned by +/// 'trace', and so the activity callback must check the trace object +/// to determine specifically what activity is being reported. +/// +/// The release callback is called for both 'trace' and for any child +/// traces spawned by 'trace'. +/// +/// \param trace Returns the new inference trace object. +/// \param level The tracing level. +/// \param parent_id The parent trace id for this trace. A value of 0 +/// indicates that there is not parent trace. +/// \param activity_fn The callback function where activity for the +/// trace is reported. +/// \param release_fn The callback function called when all activity +/// is complete for the trace. +/// \param trace_userp User-provided pointer that is delivered to +/// the activity and release callback functions. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceTraceNew( + TRITONSERVER_InferenceTrace** trace, TRITONSERVER_InferenceTraceLevel level, + uint64_t parent_id, TRITONSERVER_InferenceTraceActivityFn_t activity_fn, + TRITONSERVER_InferenceTraceReleaseFn_t release_fn, void* trace_userp); + +/// Create a new inference trace object. The caller takes ownership of +/// the TRITONSERVER_InferenceTrace object and must call +/// TRITONSERVER_InferenceTraceDelete to release the object. +/// +/// The timeline and tensor activity callback function will be called to report +/// activity for 'trace' as well as for any child traces that are spawned by +/// 'trace', and so the activity callback must check the trace object +/// to determine specifically what activity is being reported. +/// +/// The release callback is called for both 'trace' and for any child +/// traces spawned by 'trace'. +/// +/// \param trace Returns the new inference trace object. +/// \param level The tracing level. 
+/// \param parent_id The parent trace id for this trace. A value of 0 +/// indicates that there is not parent trace. +/// \param activity_fn The callback function where timeline activity for the +/// trace is reported. +/// \param tensor_activity_fn The callback function where tensor activity for +/// the trace is reported. +/// \param release_fn The callback function called when all activity +/// is complete for the trace. +/// \param trace_userp User-provided pointer that is delivered to +/// the activity and release callback functions. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceTraceTensorNew( + TRITONSERVER_InferenceTrace** trace, TRITONSERVER_InferenceTraceLevel level, + uint64_t parent_id, TRITONSERVER_InferenceTraceActivityFn_t activity_fn, + TRITONSERVER_InferenceTraceTensorActivityFn_t tensor_activity_fn, + TRITONSERVER_InferenceTraceReleaseFn_t release_fn, void* trace_userp); + +/// Delete a trace object. +/// +/// \param trace The trace object. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceTraceDelete( + TRITONSERVER_InferenceTrace* trace); + +/// Get the id associated with a trace. Every trace is assigned an id +/// that is unique across all traces created for a Triton server. +/// +/// \param trace The trace. +/// \param id Returns the id associated with the trace. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceTraceId( + TRITONSERVER_InferenceTrace* trace, uint64_t* id); + +/// Get the parent id associated with a trace. The parent id indicates +/// a parent-child relationship between two traces. A parent id value +/// of 0 indicates that there is no parent trace. +/// +/// \param trace The trace. +/// \param id Returns the parent id associated with the trace. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceTraceParentId( + TRITONSERVER_InferenceTrace* trace, uint64_t* parent_id); + +/// Get the name of the model associated with a trace. The caller does +/// not own the returned string and must not modify or delete it. The +/// lifetime of the returned string extends only as long as 'trace'. +/// +/// \param trace The trace. +/// \param model_name Returns the name of the model associated with +/// the trace. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceTraceModelName( + TRITONSERVER_InferenceTrace* trace, const char** model_name); + +/// Get the version of the model associated with a trace. +/// +/// \param trace The trace. +/// \param model_version Returns the version of the model associated +/// with the trace. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceTraceModelVersion( + TRITONSERVER_InferenceTrace* trace, int64_t* model_version); + +/// TRITONSERVER_InferenceRequest +/// +/// Object representing an inference request. The inference request +/// provides the meta-data and input tensor values needed for an +/// inference and returns the inference result meta-data and output +/// tensors. An inference request object can be modified and reused +/// multiple times. +/// + +/// Inference request flags. The enum values must be power-of-2 values. 
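+///
+/// For example (a usage sketch assuming a previously created request), the
+/// first request of a sequence could be marked by passing the corresponding
+/// flag value (flags may be OR-ed together) to
+/// TRITONSERVER_InferenceRequestSetFlags, declared later in this file:
+///
+///   TRITONSERVER_InferenceRequestSetFlags(
+///       request, TRITONSERVER_REQUEST_FLAG_SEQUENCE_START);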
+typedef enum tritonserver_requestflag_enum { + TRITONSERVER_REQUEST_FLAG_SEQUENCE_START = 1, + TRITONSERVER_REQUEST_FLAG_SEQUENCE_END = 2 +} TRITONSERVER_RequestFlag; + +/// Inference request release flags. The enum values must be +/// power-of-2 values. +typedef enum tritonserver_requestreleaseflag_enum { + TRITONSERVER_REQUEST_RELEASE_ALL = 1 +} TRITONSERVER_RequestReleaseFlag; + +/// Inference response complete flags. The enum values must be +/// power-of-2 values. +typedef enum tritonserver_responsecompleteflag_enum { + TRITONSERVER_RESPONSE_COMPLETE_FINAL = 1 +} TRITONSERVER_ResponseCompleteFlag; + +/// Type for inference request release callback function. The callback +/// indicates what type of release is being performed on the request +/// and for some of these the callback function takes ownership of the +/// TRITONSERVER_InferenceRequest object. The 'userp' data is the data +/// provided as 'request_release_userp' in the call to +/// TRITONSERVER_InferenceRequestSetReleaseCallback. +/// +/// One or more flags will be specified when the callback is invoked, +/// and the callback must take the following actions: +/// +/// - TRITONSERVER_REQUEST_RELEASE_ALL: The entire inference request +/// is being released and ownership is passed to the callback +/// function. Triton will not longer access the 'request' object +/// itself nor any input tensor data associated with the +/// request. The callback should free or otherwise manage the +/// 'request' object and all associated tensor data. +/// +/// Note that currently TRITONSERVER_REQUEST_RELEASE_ALL should always +/// be set when the callback is invoked but in the future that may +/// change, so the callback should explicitly check for the flag +/// before taking ownership of the request object. +/// +typedef void (*TRITONSERVER_InferenceRequestReleaseFn_t)( + TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp); + +/// Type for callback function indicating that an inference response +/// has completed. The callback function takes ownership of the +/// TRITONSERVER_InferenceResponse object. The 'userp' data is the +/// data provided as 'response_userp' in the call to +/// TRITONSERVER_InferenceRequestSetResponseCallback. +/// +/// One or more flags may be specified when the callback is invoked: +/// +/// - TRITONSERVER_RESPONSE_COMPLETE_FINAL: Indicates that no more +/// responses will be generated for a given request (more +/// specifically, that no more responses will be generated for the +/// inference request that set this callback and 'userp'). When +/// this flag is set 'response' may be a response object or may be +/// nullptr. If 'response' is not nullptr, then 'response' is the +/// last response that Triton will produce for the request. If +/// 'response' is nullptr then Triton is indicating that no more +/// responses will be produced for the request. +typedef void (*TRITONSERVER_InferenceResponseCompleteFn_t)( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, + void* userp); + +/// Create a new inference request object. +/// +/// \param inference_request Returns the new request object. +/// \param server the inference server object. +/// \param model_name The name of the model to use for the request. +/// \param model_version The version of the model to use for the +/// request. If -1 then the server will choose a version based on the +/// model's policy. +/// \return a TRITONSERVER_Error indicating success or failure. 
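+///
+/// A compressed usage sketch (error checking omitted; 'server' is assumed to
+/// be a live TRITONSERVER_Server and 'RequestRelease' a hypothetical callback
+/// matching TRITONSERVER_InferenceRequestReleaseFn_t):
+///
+///   TRITONSERVER_InferenceRequest* request = nullptr;
+///   TRITONSERVER_InferenceRequestNew(
+///       &request, server, "mymodel", -1 /* version chosen per model policy */);
+///   TRITONSERVER_InferenceRequestSetReleaseCallback(
+///       request, RequestRelease, nullptr /* request_release_userp */);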
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceRequestNew( + TRITONSERVER_InferenceRequest** inference_request, + TRITONSERVER_Server* server, const char* model_name, + const int64_t model_version); + +/// Delete an inference request object. +/// +/// \param inference_request The request object. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceRequestDelete( + TRITONSERVER_InferenceRequest* inference_request); + +/// Get the ID for a request. The returned ID is owned by +/// 'inference_request' and must not be modified or freed by the +/// caller. +/// +/// \param inference_request The request object. +/// \param id Returns the ID. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceRequestId( + TRITONSERVER_InferenceRequest* inference_request, const char** id); + +/// Set the ID for a request. +/// +/// \param inference_request The request object. +/// \param id The ID. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceRequestSetId( + TRITONSERVER_InferenceRequest* inference_request, const char* id); + +/// Get the flag(s) associated with a request. On return 'flags' holds +/// a bitwise-or of all flag values, see TRITONSERVER_RequestFlag for +/// available flags. +/// +/// \param inference_request The request object. +/// \param flags Returns the flags. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceRequestFlags( + TRITONSERVER_InferenceRequest* inference_request, uint32_t* flags); + +/// Set the flag(s) associated with a request. 'flags' should hold a +/// bitwise-or of all flag values, see TRITONSERVER_RequestFlag for +/// available flags. +/// +/// \param inference_request The request object. +/// \param flags The flags. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceRequestSetFlags( + TRITONSERVER_InferenceRequest* inference_request, uint32_t flags); + +/// Get the correlation ID of the inference request as an unsigned integer. +/// Default is 0, which indicates that the request has no correlation ID. +/// If the correlation id associated with the inference request is a string, +/// this function will return a failure. The correlation ID is used +/// to indicate two or more inference request are related to each other. +/// How this relationship is handled by the inference server is determined by +/// the model's scheduling policy. +/// +/// \param inference_request The request object. +/// \param correlation_id Returns the correlation ID. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestCorrelationId( + TRITONSERVER_InferenceRequest* inference_request, uint64_t* correlation_id); + +/// Get the correlation ID of the inference request as a string. +/// Default is empty "", which indicates that the request has no correlation ID. +/// If the correlation id associated with the inference request is an unsigned +/// integer, then this function will return a failure. The correlation ID +/// is used to indicate two or more inference request are related to each other. +/// How this relationship is handled by the inference server is determined by +/// the model's scheduling policy. 
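+///
+/// For illustration (a sketch assuming an existing request), a request that
+/// is part of a sequence can be given either an integer or a string
+/// correlation ID using the setters declared below:
+///
+///   TRITONSERVER_InferenceRequestSetCorrelationId(request, 42);
+///   // ... or, for models that use string sequence IDs:
+///   TRITONSERVER_InferenceRequestSetCorrelationIdString(request, "session-42");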
+/// +/// \param inference_request The request object. +/// \param correlation_id Returns the correlation ID. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestCorrelationIdString( + TRITONSERVER_InferenceRequest* inference_request, + const char** correlation_id); + +/// Set the correlation ID of the inference request to be an unsigned integer. +/// Default is 0, which indicates that the request has no correlation ID. +/// The correlation ID is used to indicate two or more inference request +/// are related to each other. How this relationship is handled by the +/// inference server is determined by the model's scheduling policy. +/// +/// \param inference_request The request object. +/// \param correlation_id The correlation ID. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetCorrelationId( + TRITONSERVER_InferenceRequest* inference_request, uint64_t correlation_id); + +/// Set the correlation ID of the inference request to be a string. +/// The correlation ID is used to indicate two or more inference +/// request are related to each other. How this relationship is +/// handled by the inference server is determined by the model's +/// scheduling policy. +/// +/// \param inference_request The request object. +/// \param correlation_id The correlation ID. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetCorrelationIdString( + TRITONSERVER_InferenceRequest* inference_request, + const char* correlation_id); + +/// Get the priority for a request. The default is 0 indicating that +/// the request does not specify a priority and so will use the +/// model's default priority. +/// +/// \param inference_request The request object. +/// \param priority Returns the priority level. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceRequestPriority( + TRITONSERVER_InferenceRequest* inference_request, uint32_t* priority); + +/// Set the priority for a request. The default is 0 indicating that +/// the request does not specify a priority and so will use the +/// model's default priority. +/// +/// \param inference_request The request object. +/// \param priority The priority level. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetPriority( + TRITONSERVER_InferenceRequest* inference_request, uint32_t priority); + +/// Get the timeout for a request, in microseconds. The default is 0 +/// which indicates that the request has no timeout. +/// +/// \param inference_request The request object. +/// \param timeout_us Returns the timeout, in microseconds. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestTimeoutMicroseconds( + TRITONSERVER_InferenceRequest* inference_request, uint64_t* timeout_us); + +/// Set the timeout for a request, in microseconds. The default is 0 +/// which indicates that the request has no timeout. +/// +/// \param inference_request The request object. +/// \param timeout_us The timeout, in microseconds. +/// \return a TRITONSERVER_Error indicating success or failure. 
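+///
+/// For example (a usage sketch with arbitrary values), a request might be
+/// given an explicit priority and a 500 millisecond timeout:
+///
+///   TRITONSERVER_InferenceRequestSetPriority(request, 1);
+///   TRITONSERVER_InferenceRequestSetTimeoutMicroseconds(request, 500 * 1000);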
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestSetTimeoutMicroseconds(
+    TRITONSERVER_InferenceRequest* inference_request, uint64_t timeout_us);
+
+/// Add an input to a request.
+///
+/// \param inference_request The request object.
+/// \param name The name of the input.
+/// \param datatype The type of the input. Valid type names are BOOL,
+/// UINT8, UINT16, UINT32, UINT64, INT8, INT16, INT32, INT64, FP16,
+/// FP32, FP64, and BYTES.
+/// \param shape The shape of the input.
+/// \param dim_count The number of dimensions of 'shape'.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceRequestAddInput(
+    TRITONSERVER_InferenceRequest* inference_request, const char* name,
+    const TRITONSERVER_DataType datatype, const int64_t* shape,
+    uint64_t dim_count);
+
+/// Add a raw input to a request. The name recognized by the model, the data
+/// type and the shape of the input will be deduced from the model
+/// configuration. This function must be called at most once on a request, and
+/// only on a request that has no other input, to ensure the deduction is
+/// accurate.
+///
+/// \param inference_request The request object.
+/// \param name The name of the input. This name is only used as a reference
+/// to the raw input in other Tritonserver APIs. It is not associated with the
+/// name used in the model.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestAddRawInput(
+    TRITONSERVER_InferenceRequest* inference_request, const char* name);
+
+/// Remove an input from a request.
+///
+/// \param inference_request The request object.
+/// \param name The name of the input.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestRemoveInput(
+    TRITONSERVER_InferenceRequest* inference_request, const char* name);
+
+/// Remove all inputs from a request.
+///
+/// \param inference_request The request object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestRemoveAllInputs(
+    TRITONSERVER_InferenceRequest* inference_request);
+
+/// Assign a buffer of data to an input. The buffer will be appended
+/// to any existing buffers for that input. The 'inference_request'
+/// object takes ownership of the buffer and so the caller should not
+/// modify or free the buffer until that ownership is released by
+/// 'inference_request' being deleted or by the input being removed
+/// from 'inference_request'.
+///
+/// \param inference_request The request object.
+/// \param name The name of the input.
+/// \param base The base address of the input data.
+/// \param byte_size The size, in bytes, of the input data.
+/// \param memory_type The memory type of the input data.
+/// \param memory_type_id The memory type id of the input data.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestAppendInputData(
+    TRITONSERVER_InferenceRequest* inference_request, const char* name,
+    const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type,
+    int64_t memory_type_id);
+
+/// Assign a buffer of data to an input for execution on all model instances
+/// with the specified host policy. The buffer will be appended to any existing
+/// buffers for that input on all devices with this host policy.
The
+/// 'inference_request' object takes ownership of the buffer and so the caller
+/// should not modify or free the buffer until that ownership is released by
+/// 'inference_request' being deleted or by the input being removed from
+/// 'inference_request'. If the execution is scheduled on a device that does not
+/// have an input buffer specified using this function, then the input buffer
+/// specified with TRITONSERVER_InferenceRequestAppendInputData will be used, so
+/// a non-host-policy-specific version of the data must be added using that API.
+///
+/// \param inference_request The request object.
+/// \param name The name of the input.
+/// \param base The base address of the input data.
+/// \param byte_size The size, in bytes, of the input data.
+/// \param memory_type The memory type of the input data.
+/// \param memory_type_id The memory type id of the input data.
+/// \param host_policy_name All model instances executing with this host_policy
+/// will use this input buffer for execution.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestAppendInputDataWithHostPolicy(
+    TRITONSERVER_InferenceRequest* inference_request, const char* name,
+    const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type,
+    int64_t memory_type_id, const char* host_policy_name);
+
+/// Assign a buffer of data to an input. The buffer will be appended
+/// to any existing buffers for that input. The 'inference_request'
+/// object takes ownership of the buffer and so the caller should not
+/// modify or free the buffer until that ownership is released by
+/// 'inference_request' being deleted or by the input being removed
+/// from 'inference_request'.
+///
+/// \param inference_request The request object.
+/// \param name The name of the input.
+/// \param base The base address of the input data.
+/// \param buffer_attributes The buffer attributes of the input.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestAppendInputDataWithBufferAttributes(
+    TRITONSERVER_InferenceRequest* inference_request, const char* name,
+    const void* base, TRITONSERVER_BufferAttributes* buffer_attributes);
+
+/// Clear all input data from an input, releasing ownership of the
+/// buffer(s) that were appended to the input with
+/// TRITONSERVER_InferenceRequestAppendInputData or
+/// TRITONSERVER_InferenceRequestAppendInputDataWithHostPolicy.
+///
+/// \param inference_request The request object.
+/// \param name The name of the input.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestRemoveAllInputData(
+    TRITONSERVER_InferenceRequest* inference_request, const char* name);
+
+/// Add an output request to an inference request.
+///
+/// \param inference_request The request object.
+/// \param name The name of the output.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceRequestAddRequestedOutput(
+    TRITONSERVER_InferenceRequest* inference_request, const char* name);
+
+/// Remove an output request from an inference request.
+///
+/// \param inference_request The request object.
+/// \param name The name of the output.
+/// \return a TRITONSERVER_Error indicating success or failure.
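+///
+/// Putting the preceding input and output functions together, a minimal
+/// sketch (CPU input buffer; the tensor names and shape are examples only and
+/// 'input_buffer' is assumed to point at 64 bytes of caller-owned data) looks
+/// like:
+///
+///   const int64_t shape[2] = {1, 16};
+///   TRITONSERVER_InferenceRequestAddInput(
+///       request, "INPUT0", TRITONSERVER_TYPE_FP32, shape, 2);
+///   TRITONSERVER_InferenceRequestAppendInputData(
+///       request, "INPUT0", input_buffer, 16 * sizeof(float),
+///       TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */);
+///   TRITONSERVER_InferenceRequestAddRequestedOutput(request, "OUTPUT0");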
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestRemoveRequestedOutput( + TRITONSERVER_InferenceRequest* inference_request, const char* name); + +/// Remove all output requests from an inference request. +/// +/// \param inference_request The request object. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestRemoveAllRequestedOutputs( + TRITONSERVER_InferenceRequest* inference_request); + +/// Set the release callback for an inference request. The release +/// callback is called by Triton to return ownership of the request +/// object. +/// +/// \param inference_request The request object. +/// \param request_release_fn The function called to return ownership +/// of the 'inference_request' object. +/// \param request_release_userp User-provided pointer that is +/// delivered to the 'request_release_fn' callback. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetReleaseCallback( + TRITONSERVER_InferenceRequest* inference_request, + TRITONSERVER_InferenceRequestReleaseFn_t request_release_fn, + void* request_release_userp); + +/// Set the allocator and response callback for an inference +/// request. The allocator is used to allocate buffers for any output +/// tensors included in responses that are produced for this +/// request. The response callback is called to return response +/// objects representing responses produced for this request. +/// +/// \param inference_request The request object. +/// \param response_allocator The TRITONSERVER_ResponseAllocator to use +/// to allocate buffers to hold inference results. +/// \param response_allocator_userp User-provided pointer that is +/// delivered to the response allocator's start and allocation functions. +/// \param response_fn The function called to deliver an inference +/// response for this request. +/// \param response_userp User-provided pointer that is delivered to +/// the 'response_fn' callback. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetResponseCallback( + TRITONSERVER_InferenceRequest* inference_request, + TRITONSERVER_ResponseAllocator* response_allocator, + void* response_allocator_userp, + TRITONSERVER_InferenceResponseCompleteFn_t response_fn, + void* response_userp); + +/// TRITONSERVER_InferenceResponse +/// +/// Object representing an inference response. The inference response +/// provides the meta-data and output tensor values calculated by the +/// inference. +/// + +/// Delete an inference response object. +/// +/// \param inference_response The response object. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceResponseDelete( + TRITONSERVER_InferenceResponse* inference_response); + +/// Return the error status of an inference response. Return a +/// TRITONSERVER_Error object on failure, return nullptr on success. +/// The returned error object is owned by 'inference_response' and so +/// should not be deleted by the caller. +/// +/// \param inference_response The response object. +/// \return a TRITONSERVER_Error indicating the success or failure +/// status of the response. 
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceResponseError(
+    TRITONSERVER_InferenceResponse* inference_response);
+
+/// Get the model used to produce a response. The caller does not own the
+/// returned model name value and must not modify or delete it. The
+/// lifetime of all returned values extends until 'inference_response'
+/// is deleted.
+///
+/// \param inference_response The response object.
+/// \param model_name Returns the name of the model.
+/// \param model_version Returns the version of the model used to produce
+/// this response.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceResponseModel(
+    TRITONSERVER_InferenceResponse* inference_response, const char** model_name,
+    int64_t* model_version);
+
+/// Get the ID of the request corresponding to a response. The caller
+/// does not own the returned ID and must not modify or delete it. The
+/// lifetime of all returned values extends until 'inference_response'
+/// is deleted.
+///
+/// \param inference_response The response object.
+/// \param request_id Returns the ID of the request corresponding to
+/// this response.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceResponseId(
+    TRITONSERVER_InferenceResponse* inference_response,
+    const char** request_id);
+
+/// Get the number of parameters available in the response.
+///
+/// \param inference_response The response object.
+/// \param count Returns the number of parameters.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceResponseParameterCount(
+    TRITONSERVER_InferenceResponse* inference_response, uint32_t* count);
+
+/// Get all information about a parameter. The caller does not own any
+/// of the returned values and must not modify or delete them. The
+/// lifetime of all returned values extends until 'inference_response'
+/// is deleted.
+///
+/// The 'vvalue' returns a void* pointer that must be cast
+/// appropriately based on 'type'. For example:
+///
+///   void* vvalue;
+///   TRITONSERVER_ParameterType type;
+///   TRITONSERVER_InferenceResponseParameter(
+///       response, index, &name, &type, &vvalue);
+///   switch (type) {
+///     case TRITONSERVER_PARAMETER_BOOL:
+///       bool value = *(reinterpret_cast<bool*>(vvalue));
+///       ...
+///     case TRITONSERVER_PARAMETER_INT:
+///       int64_t value = *(reinterpret_cast<int64_t*>(vvalue));
+///       ...
+///     case TRITONSERVER_PARAMETER_STRING:
+///       const char* value = reinterpret_cast<const char*>(vvalue);
+///       ...
+///
+/// \param inference_response The response object.
+/// \param index The index of the parameter, must be 0 <= index <
+/// count, where 'count' is the value returned by
+/// TRITONSERVER_InferenceResponseParameterCount.
+/// \param name Returns the name of the parameter.
+/// \param type Returns the type of the parameter.
+/// \param vvalue Returns a pointer to the parameter value.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceResponseParameter(
+    TRITONSERVER_InferenceResponse* inference_response, const uint32_t index,
+    const char** name, TRITONSERVER_ParameterType* type, const void** vvalue);
+
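+// Illustrative sketch (not part of the API declared above): a typical first
+// step in a response-complete callback. 'response' is assumed to be the
+// TRITONSERVER_InferenceResponse delivered to the callback.
+//
+//   TRITONSERVER_Error* err = TRITONSERVER_InferenceResponseError(response);
+//   if (err == nullptr) {
+//     const char* request_id = nullptr;
+//     TRITONSERVER_InferenceResponseId(response, &request_id);
+//     // ... read outputs while the response is still alive ...
+//   }
+//   // 'err' is owned by the response and must not be deleted by the caller.
+//   TRITONSERVER_InferenceResponseDelete(response);
+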
+/// Get the number of outputs available in the response.
+///
+/// \param inference_response The response object.
+/// \param count Returns the number of output tensors.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceResponseOutputCount(
+    TRITONSERVER_InferenceResponse* inference_response, uint32_t* count);
+
+/// Get all information about an output tensor. The tensor data is
+/// returned as the base pointer to the data and the size, in bytes,
+/// of the data. The caller does not own any of the returned values
+/// and must not modify or delete them. The lifetime of all returned
+/// values extends until 'inference_response' is deleted.
+///
+/// \param inference_response The response object.
+/// \param index The index of the output tensor, must be 0 <= index <
+/// count, where 'count' is the value returned by
+/// TRITONSERVER_InferenceResponseOutputCount.
+/// \param name Returns the name of the output.
+/// \param datatype Returns the type of the output.
+/// \param shape Returns the shape of the output.
+/// \param dim_count Returns the number of dimensions of the returned
+/// shape.
+/// \param base Returns the tensor data for the output.
+/// \param byte_size Returns the size, in bytes, of the data.
+/// \param memory_type Returns the memory type of the data.
+/// \param memory_type_id Returns the memory type id of the data.
+/// \param userp The user-specified value associated with the buffer
+/// in TRITONSERVER_ResponseAllocatorAllocFn_t.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_InferenceResponseOutput(
+    TRITONSERVER_InferenceResponse* inference_response, const uint32_t index,
+    const char** name, TRITONSERVER_DataType* datatype, const int64_t** shape,
+    uint64_t* dim_count, const void** base, size_t* byte_size,
+    TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id,
+    void** userp);
+
+/// Get a classification label associated with an output for a given
+/// index. The caller does not own the returned label and must not
+/// modify or delete it. The lifetime of the returned label extends
+/// until 'inference_response' is deleted.
+///
+/// \param inference_response The response object.
+/// \param index The index of the output tensor, must be 0 <= index <
+/// count, where 'count' is the value returned by
+/// TRITONSERVER_InferenceResponseOutputCount.
+/// \param class_index The index of the class.
+/// \param label Returns the label corresponding to 'class_index' or
+/// nullptr if there is no label.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_InferenceResponseOutputClassificationLabel(
+    TRITONSERVER_InferenceResponse* inference_response, const uint32_t index,
+    const size_t class_index, const char** label);
+
+/// TRITONSERVER_BufferAttributes
+///
+/// API to create, modify, or retrieve attributes associated with a buffer.
+///
+
+/// Create a new buffer attributes object. The caller takes ownership of
+/// the TRITONSERVER_BufferAttributes object and must call
+/// TRITONSERVER_BufferAttributesDelete to release the object.
+///
+/// \param buffer_attributes Returns the new buffer attributes object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_BufferAttributesNew(
+    TRITONSERVER_BufferAttributes** buffer_attributes);
+
+/// Delete a buffer attributes object.
+///
+/// \param buffer_attributes The buffer_attributes object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_BufferAttributesDelete(
+    TRITONSERVER_BufferAttributes* buffer_attributes);
+
+/// Set the memory type id field of the buffer attributes.
+///
+/// \param buffer_attributes The buffer attributes object.
+/// \param memory_type_id Memory type id to assign to the buffer attributes
+/// object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_BufferAttributesSetMemoryTypeId(
+    TRITONSERVER_BufferAttributes* buffer_attributes, int64_t memory_type_id);
+
+/// Set the memory type field of the buffer attributes.
+///
+/// \param buffer_attributes The buffer attributes object.
+/// \param memory_type Memory type to assign to the buffer attributes object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_BufferAttributesSetMemoryType(
+    TRITONSERVER_BufferAttributes* buffer_attributes,
+    TRITONSERVER_MemoryType memory_type);
+
+/// Set the CudaIpcHandle field of the buffer attributes.
+///
+/// \param buffer_attributes The buffer attributes object.
+/// \param cuda_ipc_handle The CudaIpcHandle to assign to the buffer attributes
+/// object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_BufferAttributesSetCudaIpcHandle(
+    TRITONSERVER_BufferAttributes* buffer_attributes, void* cuda_ipc_handle);
+
+/// Set the byte size field of the buffer attributes.
+///
+/// \param buffer_attributes The buffer attributes object.
+/// \param byte_size Byte size to assign to the buffer attributes object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_BufferAttributesSetByteSize(
+    TRITONSERVER_BufferAttributes* buffer_attributes, size_t byte_size);
+
+/// Get the memory type id field of the buffer attributes.
+///
+/// \param buffer_attributes The buffer attributes object.
+/// \param memory_type_id Returns the memory type id associated with the buffer
+/// attributes object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_BufferAttributesMemoryTypeId(
+    TRITONSERVER_BufferAttributes* buffer_attributes, int64_t* memory_type_id);
+
+/// Get the memory type field of the buffer attributes.
+///
+/// \param buffer_attributes The buffer attributes object.
+/// \param memory_type Returns the memory type associated with the buffer
+/// attributes object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_BufferAttributesMemoryType(
+    TRITONSERVER_BufferAttributes* buffer_attributes,
+    TRITONSERVER_MemoryType* memory_type);
+
+/// Get the CudaIpcHandle field of the buffer attributes object.
+///
+/// \param buffer_attributes The buffer attributes object.
+/// \param cuda_ipc_handle Returns the CudaIpcHandle associated with the buffer
+/// attributes object. If the CudaIpcHandle does not exist for the buffer,
+/// nullptr will be returned.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_BufferAttributesCudaIpcHandle(
+    TRITONSERVER_BufferAttributes* buffer_attributes, void** cuda_ipc_handle);
+
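+// Illustrative sketch (not part of the API declared above): describing a
+// 64-byte CPU buffer with a buffer attributes object and attaching it to an
+// input of an existing request. 'request', 'buffer_base' and the input name
+// "INPUT0" are assumptions for this example.
+//
+//   TRITONSERVER_BufferAttributes* attrs = nullptr;
+//   TRITONSERVER_BufferAttributesNew(&attrs);
+//   TRITONSERVER_BufferAttributesSetMemoryType(attrs, TRITONSERVER_MEMORY_CPU);
+//   TRITONSERVER_BufferAttributesSetMemoryTypeId(attrs, 0);
+//   TRITONSERVER_BufferAttributesSetByteSize(attrs, 64);
+//   TRITONSERVER_InferenceRequestAppendInputDataWithBufferAttributes(
+//       request, "INPUT0", buffer_base, attrs);
+//   // Release 'attrs' with TRITONSERVER_BufferAttributesDelete once it is
+//   // no longer in use.
+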
+/// Get the byte size field of the buffer attributes.
+///
+/// \param buffer_attributes The buffer attributes object.
+/// \param byte_size Returns the byte size associated with the buffer attributes
+/// object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_BufferAttributesByteSize(
+    TRITONSERVER_BufferAttributes* buffer_attributes, size_t* byte_size);
+
+
+/// TRITONSERVER_ServerOptions
+///
+/// Options to use when creating an inference server.
+///
+
+/// Model control modes
+typedef enum tritonserver_modelcontrolmode_enum {
+  TRITONSERVER_MODEL_CONTROL_NONE,
+  TRITONSERVER_MODEL_CONTROL_POLL,
+  TRITONSERVER_MODEL_CONTROL_EXPLICIT
+} TRITONSERVER_ModelControlMode;
+
+/// Rate limit modes
+typedef enum tritonserver_ratelimitmode_enum {
+  TRITONSERVER_RATE_LIMIT_OFF,
+  TRITONSERVER_RATE_LIMIT_EXEC_COUNT
+} TRITONSERVER_RateLimitMode;
+
+/// Create a new server options object. The caller takes ownership of
+/// the TRITONSERVER_ServerOptions object and must call
+/// TRITONSERVER_ServerOptionsDelete to release the object.
+///
+/// \param options Returns the new server options object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerOptionsNew(
+    TRITONSERVER_ServerOptions** options);
+
+/// Delete a server options object.
+///
+/// \param options The server options object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerOptionsDelete(
+    TRITONSERVER_ServerOptions* options);
+
+/// Set the textual ID for the server in a server options. The ID is a
+/// name that identifies the server.
+///
+/// \param options The server options object.
+/// \param server_id The server identifier.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerOptionsSetServerId(
+    TRITONSERVER_ServerOptions* options, const char* server_id);
+
+/// Set the model repository path in a server options. The path must be
+/// the full absolute path to the model repository. This function can be called
+/// multiple times with different paths to set multiple model repositories.
+/// Note that if a model is not unique across all model repositories
+/// at any time, the model will not be available.
+///
+/// \param options The server options object.
+/// \param model_repository_path The full path to the model repository.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetModelRepositoryPath(
+    TRITONSERVER_ServerOptions* options, const char* model_repository_path);
+
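+// Illustrative sketch (not part of the API declared above): creating a server
+// options object and pointing it at a local model repository. The server id
+// "example-server" and the path "/opt/models" are placeholders.
+//
+//   TRITONSERVER_ServerOptions* options = nullptr;
+//   TRITONSERVER_ServerOptionsNew(&options);
+//   TRITONSERVER_ServerOptionsSetServerId(options, "example-server");
+//   TRITONSERVER_ServerOptionsSetModelRepositoryPath(options, "/opt/models");
+//   // ... set any additional options, create the server, then release:
+//   TRITONSERVER_ServerOptionsDelete(options);
+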
+/// Set the model control mode in a server options. For each mode the models
+/// will be managed as follows:
+///
+/// TRITONSERVER_MODEL_CONTROL_NONE: the models in the model repository will be
+/// loaded on startup. After startup any changes to the model repository will
+/// be ignored. Calling TRITONSERVER_ServerPollModelRepository will result in
+/// an error.
+///
+/// TRITONSERVER_MODEL_CONTROL_POLL: the models in the model repository will be
+/// loaded on startup. The model repository can be polled periodically using
+/// TRITONSERVER_ServerPollModelRepository and the server will load, unload,
+/// and update models according to changes in the model repository.
+///
+/// TRITONSERVER_MODEL_CONTROL_EXPLICIT: the models in the model repository
+/// will not be loaded on startup. The corresponding model control APIs must be
+/// called to load / unload a model in the model repository.
+///
+/// \param options The server options object.
+/// \param mode The mode to use for the model control.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetModelControlMode(
+    TRITONSERVER_ServerOptions* options, TRITONSERVER_ModelControlMode mode);
+
+/// Set a model to be loaded at startup in a server options. The model must be
+/// present in one, and only one, of the specified model repositories.
+/// This function can be called multiple times with different model names
+/// to set multiple startup models.
+/// Note that it only takes effect in TRITONSERVER_MODEL_CONTROL_EXPLICIT mode.
+///
+/// \param options The server options object.
+/// \param model_name The name of the model to load on startup.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetStartupModel(
+    TRITONSERVER_ServerOptions* options, const char* model_name);
+
+/// Enable or disable strict model configuration handling in a server
+/// options.
+///
+/// \param options The server options object.
+/// \param strict True to enable strict model configuration handling,
+/// false to disable.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetStrictModelConfig(
+    TRITONSERVER_ServerOptions* options, bool strict);
+
+/// Set the rate limit mode in a server options.
+///
+/// TRITONSERVER_RATE_LIMIT_EXEC_COUNT: The rate limiting prioritizes the
+/// inference execution using the number of times each instance has been given
+/// a chance to run. The execution gets to run only when its resource
+/// constraints are satisfied.
+///
+/// TRITONSERVER_RATE_LIMIT_OFF: The rate limiting is turned off and the
+/// inference gets executed whenever an instance is available.
+///
+/// \param options The server options object.
+/// \param mode The mode to use for the rate limiting. By default, execution
+/// count is used to determine the priorities.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetRateLimiterMode(
+    TRITONSERVER_ServerOptions* options, TRITONSERVER_RateLimitMode mode);
+
+/// Add a resource count for rate limiting.
+///
+/// \param options The server options object.
+/// \param resource_name The name of the resource.
+/// \param resource_count The count of the resource.
+/// \param device The device identifier for the resource. A value of -1
+/// indicates that the specified number of resources are available on every
+/// device. The device value is ignored for a global resource. The server
+/// will use the rate limiter configuration specified for instance groups
+/// in the model config to determine whether a resource is global. In case of
+/// conflicting resource types in different model configurations, the server
+/// will raise an appropriate error while loading the model.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsAddRateLimiterResource(
+    TRITONSERVER_ServerOptions* options, const char* resource_name,
+    const size_t resource_count, const int device);
+
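+// Illustrative sketch (not part of the API declared above): enabling
+// execution-count based rate limiting and declaring a resource named "R1"
+// with two units available on every device. The resource name is a
+// placeholder; 'options' is assumed to exist.
+//
+//   TRITONSERVER_ServerOptionsSetRateLimiterMode(
+//       options, TRITONSERVER_RATE_LIMIT_EXEC_COUNT);
+//   TRITONSERVER_ServerOptionsAddRateLimiterResource(
+//       options, "R1", 2 /* resource_count */, -1 /* every device */);
+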
+/// Set the total pinned memory byte size that the server can allocate
+/// in a server options. The pinned memory pool will be shared across
+/// Triton itself and the backends that use
+/// TRITONBACKEND_MemoryManager to allocate memory.
+///
+/// \param options The server options object.
+/// \param size The pinned memory pool byte size.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetPinnedMemoryPoolByteSize(
+    TRITONSERVER_ServerOptions* options, uint64_t size);
+
+/// Set the total CUDA memory byte size that the server can allocate
+/// on a given GPU device in a server options. The CUDA memory pool
+/// will be shared across Triton itself and the backends that use
+/// TRITONBACKEND_MemoryManager to allocate memory.
+///
+/// \param options The server options object.
+/// \param gpu_device The GPU device to allocate the memory pool.
+/// \param size The CUDA memory pool byte size.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetCudaMemoryPoolByteSize(
+    TRITONSERVER_ServerOptions* options, int gpu_device, uint64_t size);
+
+/// Set the total response cache byte size that the server can allocate in CPU
+/// memory. The response cache will be shared across all inference requests and
+/// across all models.
+///
+/// \param options The server options object.
+/// \param size The total response cache byte size.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetResponseCacheByteSize(
+    TRITONSERVER_ServerOptions* options, uint64_t size);
+
+/// Set the minimum supported CUDA compute capability in a server
+/// options.
+///
+/// \param options The server options object.
+/// \param cc The minimum CUDA compute capability.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetMinSupportedComputeCapability(
+    TRITONSERVER_ServerOptions* options, double cc);
+
+/// Enable or disable exit-on-error in a server options.
+///
+/// \param options The server options object.
+/// \param exit True to enable exiting on initialization error, false
+/// to continue.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetExitOnError(
+    TRITONSERVER_ServerOptions* options, bool exit);
+
+/// Enable or disable strict readiness handling in a server options.
+///
+/// \param options The server options object.
+/// \param strict True to enable strict readiness handling, false to
+/// disable.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetStrictReadiness(
+    TRITONSERVER_ServerOptions* options, bool strict);
+
+/// Set the exit timeout, in seconds, for the server in a server
+/// options.
+///
+/// \param options The server options object.
+/// \param timeout The exit timeout, in seconds.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetExitTimeout(
+    TRITONSERVER_ServerOptions* options, unsigned int timeout);
+
+/// Set the number of threads used in the buffer manager in a server options.
+///
+/// \param options The server options object.
+/// \param thread_count The number of threads.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetBufferManagerThreadCount( + TRITONSERVER_ServerOptions* options, unsigned int thread_count); + +/// Set the number of threads to concurrently load models in a server options. +/// +/// \param options The server options object. +/// \param thread_count The number of threads. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetModelLoadThreadCount( + TRITONSERVER_ServerOptions* options, unsigned int thread_count); + +/// Provide a log output file. +/// +/// \param options The server options object. +/// \param file a string defining the file where the log outputs will be saved. +/// An empty string for the file name will cause triton to direct logging +/// facilities to the console +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerOptionsSetLogFile( + TRITONSERVER_ServerOptions* options, const char* file); + +/// Enable or disable info level logging. +/// +/// \param options The server options object. +/// \param log True to enable info logging, false to disable. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerOptionsSetLogInfo( + TRITONSERVER_ServerOptions* options, bool log); + +/// Enable or disable warning level logging. +/// +/// \param options The server options object. +/// \param log True to enable warning logging, false to disable. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerOptionsSetLogWarn( + TRITONSERVER_ServerOptions* options, bool log); + +/// Enable or disable error level logging. +/// +/// \param options The server options object. +/// \param log True to enable error logging, false to disable. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerOptionsSetLogError( + TRITONSERVER_ServerOptions* options, bool log); + +/// Set the logging format. +/// +/// \param options The server options object. +/// \param format The logging format. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetLogFormat( + TRITONSERVER_ServerOptions* options, const TRITONSERVER_LogFormat format); + +/// Set verbose logging level. Level zero disables verbose logging. +/// +/// \param options The server options object. +/// \param level The verbose logging level. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetLogVerbose( + TRITONSERVER_ServerOptions* options, int level); + +/// Enable or disable metrics collection in a server options. +/// +/// \param options The server options object. +/// \param metrics True to enable metrics, false to disable. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerOptionsSetMetrics( + TRITONSERVER_ServerOptions* options, bool metrics); + +/// Enable or disable GPU metrics collection in a server options. GPU +/// metrics are collected if both this option and +/// TRITONSERVER_ServerOptionsSetMetrics are true. +/// +/// \param options The server options object. +/// \param gpu_metrics True to enable GPU metrics, false to disable. 
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetGpuMetrics(
+    TRITONSERVER_ServerOptions* options, bool gpu_metrics);
+
+/// Enable or disable CPU metrics collection in a server options. CPU
+/// metrics are collected if both this option and
+/// TRITONSERVER_ServerOptionsSetMetrics are true.
+///
+/// \param options The server options object.
+/// \param cpu_metrics True to enable CPU metrics, false to disable.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetCpuMetrics(
+    TRITONSERVER_ServerOptions* options, bool cpu_metrics);
+
+/// Set the interval for metrics collection in a server options.
+/// This is 2000 milliseconds by default.
+///
+/// \param options The server options object.
+/// \param metrics_interval_ms The time interval in ms between
+/// successive metrics updates.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetMetricsInterval(
+    TRITONSERVER_ServerOptions* options, uint64_t metrics_interval_ms);
+
+/// Set the directory containing backend shared libraries. This
+/// directory is searched last after the version and model directory
+/// in the model repository when looking for the backend shared
+/// library for a model. If the backend is named 'be' the directory
+/// searched is 'backend_dir'/be/libtriton_be.so.
+///
+/// \param options The server options object.
+/// \param backend_dir The full path of the backend directory.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetBackendDirectory(
+    TRITONSERVER_ServerOptions* options, const char* backend_dir);
+
+/// Set the directory containing repository agent shared libraries. This
+/// directory is searched when looking for the repository agent shared
+/// library for a model. If the repository agent is named 'ra' the directory
+/// searched is 'repoagent_dir'/ra/libtritonrepoagent_ra.so.
+///
+/// \param options The server options object.
+/// \param repoagent_dir The full path of the repository agent directory.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetRepoAgentDirectory(
+    TRITONSERVER_ServerOptions* options, const char* repoagent_dir);
+
+/// Specify the limit on memory usage as a fraction on the device identified by
+/// 'kind' and 'device_id'. If model loading on the device is requested and the
+/// current memory usage exceeds the limit, the load will be rejected. If not
+/// specified, the limit will not be set.
+///
+/// Currently only TRITONSERVER_INSTANCEGROUPKIND_GPU is supported.
+///
+/// \param options The server options object.
+/// \param kind The kind of the device.
+/// \param device_id The id of the device.
+/// \param fraction The limit on memory usage as a fraction.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetModelLoadDeviceLimit(
+    TRITONSERVER_ServerOptions* options,
+    const TRITONSERVER_InstanceGroupKind kind, const int device_id,
+    const double fraction);
+
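+// Illustrative sketch (not part of the API declared above): rejecting model
+// loads on GPU 0 once more than 80% of its memory is already in use.
+// 'options' is assumed to exist.
+//
+//   TRITONSERVER_ServerOptionsSetModelLoadDeviceLimit(
+//       options, TRITONSERVER_INSTANCEGROUPKIND_GPU, 0 /* device_id */, 0.8);
+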
+/// Set a configuration setting for a named backend in a server
+/// options.
+///
+/// \param options The server options object.
+/// \param backend_name The name of the backend.
+/// \param setting The name of the setting.
+/// \param value The setting value.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetBackendConfig(
+    TRITONSERVER_ServerOptions* options, const char* backend_name,
+    const char* setting, const char* value);
+
+/// Set a host policy setting for a given policy name in a server options.
+///
+/// \param options The server options object.
+/// \param policy_name The name of the policy.
+/// \param setting The name of the setting.
+/// \param value The setting value.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerOptionsSetHostPolicy(
+    TRITONSERVER_ServerOptions* options, const char* policy_name,
+    const char* setting, const char* value);
+
+/// TRITONSERVER_Server
+///
+/// An inference server.
+///
+
+/// Model batch flags. The enum values must be power-of-2 values.
+typedef enum tritonserver_batchflag_enum {
+  TRITONSERVER_BATCH_UNKNOWN = 1,
+  TRITONSERVER_BATCH_FIRST_DIM = 2
+} TRITONSERVER_ModelBatchFlag;
+
+/// Model index flags. The enum values must be power-of-2 values.
+typedef enum tritonserver_modelindexflag_enum {
+  TRITONSERVER_INDEX_FLAG_READY = 1
+} TRITONSERVER_ModelIndexFlag;
+
+/// Model transaction policy flags. The enum values must be
+/// power-of-2 values.
+typedef enum tritonserver_txn_property_flag_enum {
+  TRITONSERVER_TXN_ONE_TO_ONE = 1,
+  TRITONSERVER_TXN_DECOUPLED = 2
+} TRITONSERVER_ModelTxnPropertyFlag;
+
+/// Create a new server object. The caller takes ownership of the
+/// TRITONSERVER_Server object and must call TRITONSERVER_ServerDelete
+/// to release the object.
+///
+/// \param server Returns the new inference server object.
+/// \param options The inference server options object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerNew(
+    TRITONSERVER_Server** server, TRITONSERVER_ServerOptions* options);
+
+/// Delete a server object. If the server is not already stopped it is
+/// stopped before being deleted.
+///
+/// \param server The inference server object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerDelete(
+    TRITONSERVER_Server* server);
+
+/// Stop a server object. A server can't be restarted once it is
+/// stopped.
+///
+/// \param server The inference server object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerStop(
+    TRITONSERVER_Server* server);
+
+/// Register a new model repository. Not available in polling mode.
+///
+/// \param server The inference server object.
+/// \param repository_path The full path to the model repository.
+/// \param name_mapping List of name_mapping parameters. Each mapping has
+/// the model directory name as its key and the overridden model name as its
+/// value.
+/// \param mapping_count Number of mappings provided.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerRegisterModelRepository(
+    TRITONSERVER_Server* server, const char* repository_path,
+    const TRITONSERVER_Parameter** name_mapping, const uint32_t mapping_count);
+
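+// Illustrative sketch (not part of the API declared above): registering an
+// additional repository at a placeholder path and exposing the model
+// directory "dir_a" under the name "model_a". The name mapping is built with
+// TRITONSERVER_ParameterNew, declared elsewhere in this header; error
+// checking is omitted for brevity.
+//
+//   TRITONSERVER_Parameter* mapping = TRITONSERVER_ParameterNew(
+//       "dir_a", TRITONSERVER_PARAMETER_STRING, "model_a");
+//   const TRITONSERVER_Parameter* mappings[1] = {mapping};
+//   TRITONSERVER_ServerRegisterModelRepository(
+//       server, "/opt/extra_models", mappings, 1 /* mapping_count */);
+//   // 'mapping' should later be released with TRITONSERVER_ParameterDelete.
+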
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerUnregisterModelRepository(
+    TRITONSERVER_Server* server, const char* repository_path);
+
+/// Check the model repository for changes and update server state
+/// based on those changes.
+///
+/// \param server The inference server object.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerPollModelRepository(TRITONSERVER_Server* server);
+
+/// Is the server live?
+///
+/// \param server The inference server object.
+/// \param live Returns true if the server is live, false otherwise.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerIsLive(
+    TRITONSERVER_Server* server, bool* live);
+
+/// Is the server ready?
+///
+/// \param server The inference server object.
+/// \param ready Returns true if the server is ready, false otherwise.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerIsReady(
+    TRITONSERVER_Server* server, bool* ready);
+
+/// Is the model ready?
+///
+/// \param server The inference server object.
+/// \param model_name The name of the model to get readiness for.
+/// \param model_version The version of the model to get readiness
+/// for. If -1 then the server will choose a version based on the
+/// model's policy.
+/// \param ready Returns true if the model is ready, false otherwise.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerModelIsReady(
+    TRITONSERVER_Server* server, const char* model_name,
+    const int64_t model_version, bool* ready);
+
+/// Get the batch properties of the model. The properties are
+/// communicated by a flags value and an (optional) object returned by
+/// 'voidp'.
+///
+///   - TRITONSERVER_BATCH_UNKNOWN: Triton cannot determine the
+///     batching properties of the model. This means that the model
+///     does not support batching in any way that is usable by
+///     Triton. The returned 'voidp' value is nullptr.
+///
+///   - TRITONSERVER_BATCH_FIRST_DIM: The model supports batching
+///     along the first dimension of every input and output
+///     tensor. Triton schedulers that perform batching can
+///     automatically batch inference requests along this dimension.
+///     The returned 'voidp' value is nullptr.
+///
+/// \param server The inference server object.
+/// \param model_name The name of the model.
+/// \param model_version The version of the model. If -1 then the
+/// server will choose a version based on the model's policy.
+/// \param flags Returns flags indicating the batch properties of the
+/// model.
+/// \param voidp If non-nullptr, returns a pointer specific to the
+/// 'flags' value.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerModelBatchProperties(
+    TRITONSERVER_Server* server, const char* model_name,
+    const int64_t model_version, uint32_t* flags, void** voidp);
+
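+// Illustrative sketch (not part of the API declared above): checking whether
+// a model supports batching along its first dimension. The model name
+// "my_model" is a placeholder; 'server' is assumed to exist.
+//
+//   uint32_t flags = 0;
+//   void* voidp = nullptr;
+//   TRITONSERVER_ServerModelBatchProperties(
+//       server, "my_model", -1 /* let server choose version */, &flags,
+//       &voidp);
+//   const bool batches_first_dim =
+//       (flags & TRITONSERVER_BATCH_FIRST_DIM) != 0;
+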
+/// Get the transaction policy of the model. The policy is
+/// communicated by a flags value.
+///
+///   - TRITONSERVER_TXN_ONE_TO_ONE: The model generates exactly
+///     one response per request.
+///
+///   - TRITONSERVER_TXN_DECOUPLED: The model may generate zero
+///     to many responses per request.
+///
+/// \param server The inference server object.
+/// \param model_name The name of the model.
+/// \param model_version The version of the model. If -1 then the
+/// server will choose a version based on the model's policy.
+/// \param txn_flags Returns flags indicating the transaction policy of the
+/// model.
+/// \param voidp If non-nullptr, returns a pointer specific to the 'txn_flags'
+/// value.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerModelTransactionProperties(
+    TRITONSERVER_Server* server, const char* model_name,
+    const int64_t model_version, uint32_t* txn_flags, void** voidp);
+
+/// Get the metadata of the server as a TRITONSERVER_Message object.
+/// The caller takes ownership of the message object and must call
+/// TRITONSERVER_MessageDelete to release the object.
+///
+/// \param server The inference server object.
+/// \param server_metadata Returns the server metadata message.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerMetadata(
+    TRITONSERVER_Server* server, TRITONSERVER_Message** server_metadata);
+
+/// Get the metadata of a model as a TRITONSERVER_Message
+/// object. The caller takes ownership of the message object and must
+/// call TRITONSERVER_MessageDelete to release the object.
+///
+/// \param server The inference server object.
+/// \param model_name The name of the model.
+/// \param model_version The version of the model.
+/// If -1 then the server will choose a version based on the model's
+/// policy.
+/// \param model_metadata Returns the model metadata message.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerModelMetadata(
+    TRITONSERVER_Server* server, const char* model_name,
+    const int64_t model_version, TRITONSERVER_Message** model_metadata);
+
+/// Get the statistics of a model as a TRITONSERVER_Message
+/// object. The caller takes ownership of the object and must call
+/// TRITONSERVER_MessageDelete to release the object.
+///
+/// \param server The inference server object.
+/// \param model_name The name of the model.
+/// If empty, then statistics for all available models will be returned,
+/// and the server will choose a version based on those models' policies.
+/// \param model_version The version of the model. If -1 then the
+/// server will choose a version based on the model's policy.
+/// \param model_stats Returns the model statistics message.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerModelStatistics(
+    TRITONSERVER_Server* server, const char* model_name,
+    const int64_t model_version, TRITONSERVER_Message** model_stats);
+
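+// Illustrative sketch (not part of the API declared above): fetching server
+// metadata and serializing it to JSON using
+// TRITONSERVER_MessageSerializeToJson, which is declared elsewhere in this
+// header. 'server' is assumed to exist; error checking is omitted.
+//
+//   TRITONSERVER_Message* metadata = nullptr;
+//   TRITONSERVER_ServerMetadata(server, &metadata);
+//   const char* json_base = nullptr;
+//   size_t json_size = 0;
+//   TRITONSERVER_MessageSerializeToJson(metadata, &json_base, &json_size);
+//   // 'json_base' is owned by 'metadata'; copy it if needed, then release:
+//   TRITONSERVER_MessageDelete(metadata);
+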
+/// Get the configuration of a model as a TRITONSERVER_Message object.
+/// The caller takes ownership of the message object and must call
+/// TRITONSERVER_MessageDelete to release the object.
+///
+/// \param server The inference server object.
+/// \param model_name The name of the model.
+/// \param model_version The version of the model. If -1 then the
+/// server will choose a version based on the model's policy.
+/// \param config_version The model configuration will be returned in
+/// a format matching this version. If the configuration cannot be
+/// represented in the requested version's format then an error will
+/// be returned. Currently only version 1 is supported.
+/// \param model_config Returns the model config message.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerModelConfig(
+    TRITONSERVER_Server* server, const char* model_name,
+    const int64_t model_version, const uint32_t config_version,
+    TRITONSERVER_Message** model_config);
+
+/// Get the index of all unique models in the model repositories as a
+/// TRITONSERVER_Message object. The caller takes ownership of the
+/// message object and must call TRITONSERVER_MessageDelete to release
+/// the object.
+///
+/// If TRITONSERVER_INDEX_FLAG_READY is set in 'flags' only the models
+/// that are loaded into the server and ready for inferencing are
+/// returned.
+///
+/// \param server The inference server object.
+/// \param flags TRITONSERVER_ModelIndexFlag flags that control how to
+/// collect the index.
+/// \param model_index Return the model index message that holds the
+/// index of all models contained in the server's model repositories.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerModelIndex(
+    TRITONSERVER_Server* server, uint32_t flags,
+    TRITONSERVER_Message** model_index);
+
+/// Load the requested model or reload the model if it is already
+/// loaded. The function does not return until the model is loaded or
+/// fails to load. The returned error indicates whether the model loaded
+/// successfully or not.
+///
+/// \param server The inference server object.
+/// \param model_name The name of the model.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerLoadModel(
+    TRITONSERVER_Server* server, const char* model_name);
+
+/// Load the requested model or reload the model if it is already
+/// loaded, with load parameters provided. The function does not return until
+/// the model is loaded or fails to load. The returned error indicates whether
+/// the model loaded successfully or not.
+/// Currently the following parameter names are recognized:
+/// - "config" : string parameter that contains a JSON representation of the
+/// model configuration. This config will be used for loading the model instead
+/// of the one in the model directory.
+///
+/// \param server The inference server object.
+/// \param model_name The name of the model.
+/// \param parameters The array of load parameters.
+/// \param parameter_count The number of parameters.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerLoadModelWithParameters(
+    TRITONSERVER_Server* server, const char* model_name,
+    const TRITONSERVER_Parameter** parameters, const uint64_t parameter_count);
+
+/// Unload the requested model. Unloading a model that is not loaded
+/// on the server has no effect and a success code will be returned.
+/// The function does not wait for the requested model to be fully unloaded
+/// and a success code will be returned.
+/// The returned error indicates whether the unload was initiated successfully
+/// or not.
+///
+/// \param server The inference server object.
+/// \param model_name The name of the model.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerUnloadModel(
+    TRITONSERVER_Server* server, const char* model_name);
+
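+// Illustrative sketch (not part of the API declared above): explicitly
+// loading and later unloading a model by name, which is meaningful when the
+// server runs in TRITONSERVER_MODEL_CONTROL_EXPLICIT mode. "my_model" is a
+// placeholder; 'server' is assumed to exist.
+//
+//   TRITONSERVER_Error* err = TRITONSERVER_ServerLoadModel(server, "my_model");
+//   if (err != nullptr) {
+//     // handle or log the load failure, then TRITONSERVER_ErrorDelete(err)
+//   }
+//   // ... run inference ...
+//   TRITONSERVER_ServerUnloadModel(server, "my_model");
+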
+/// Unload the requested model, and also unload any dependent model that
+/// was loaded along with the requested model (for example, the models
+/// composing an ensemble). Unloading a model that is not loaded
+/// on the server has no effect and a success code will be returned.
+/// The function does not wait for the requested model and all dependent
+/// models to be fully unloaded and a success code will be returned.
+/// The returned error indicates whether the unload was initiated successfully
+/// or not.
+///
+/// \param server The inference server object.
+/// \param model_name The name of the model.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error*
+TRITONSERVER_ServerUnloadModelAndDependents(
+    TRITONSERVER_Server* server, const char* model_name);
+
+/// Get the current metrics for the server. The caller takes ownership
+/// of the metrics object and must call TRITONSERVER_MetricsDelete to
+/// release the object.
+///
+/// \param server The inference server object.
+/// \param metrics Returns the metrics.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerMetrics(
+    TRITONSERVER_Server* server, TRITONSERVER_Metrics** metrics);
+
+/// Perform inference using the meta-data and inputs supplied by the
+/// 'inference_request'. If the function returns success, then the
+/// caller releases ownership of 'inference_request' and must not
+/// access it in any way after this call, until ownership is returned
+/// via the 'request_release_fn' callback registered in the request
+/// object with TRITONSERVER_InferenceRequestSetReleaseCallback.
+///
+/// The function unconditionally takes ownership of 'trace' and so the
+/// caller must not access it in any way after this call (except in
+/// the trace activity callbacks) until ownership is returned via the
+/// trace's release_fn callback.
+///
+/// Responses produced for this request are returned using the
+/// allocator and callback registered with the request by
+/// TRITONSERVER_InferenceRequestSetResponseCallback.
+///
+/// \param server The inference server object.
+/// \param inference_request The request object.
+/// \param trace The trace object for this request, or nullptr if no
+/// tracing.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_ServerInferAsync(
+    TRITONSERVER_Server* server,
+    TRITONSERVER_InferenceRequest* inference_request,
+    TRITONSERVER_InferenceTrace* trace);
+
+/// TRITONSERVER_MetricKind
+///
+/// Types of metrics recognized by TRITONSERVER.
+///
+typedef enum TRITONSERVER_metrickind_enum {
+  TRITONSERVER_METRIC_KIND_COUNTER,
+  TRITONSERVER_METRIC_KIND_GAUGE
+} TRITONSERVER_MetricKind;
+
+/// Create a new metric family object. The caller takes ownership of the
+/// TRITONSERVER_MetricFamily object and must call
+/// TRITONSERVER_MetricFamilyDelete to release the object.
+///
+/// \param family Returns the new metric family object.
+/// \param kind The type of metric family to create.
+/// \param name The name of the metric family seen when calling the metrics
+/// endpoint.
+/// \param description The description of the metric family seen when
+/// calling the metrics endpoint.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_MetricFamilyNew(
+    TRITONSERVER_MetricFamily** family, const TRITONSERVER_MetricKind kind,
+    const char* name, const char* description);
+
+/// Delete a metric family object.
+/// A TRITONSERVER_MetricFamily* object should be deleted AFTER its
+/// corresponding TRITONSERVER_Metric* objects have been deleted.
+/// Attempting to delete a family before its metrics will return an error. +/// +/// \param family The metric family object to delete. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_MetricFamilyDelete( + TRITONSERVER_MetricFamily* family); + +/// Create a new metric object. The caller takes ownership of the +/// TRITONSERVER_Metric object and must call +/// TRITONSERVER_MetricDelete to release the object. The caller is also +/// responsible for ownership of the labels passed in. Each label can be deleted +/// immediately after creating the metric with TRITONSERVER_ParameterDelete +/// if not re-using the labels. +/// +/// \param metric Returns the new metric object. +/// \param family The metric family to add this new metric to. +/// \param labels The array of labels to associate with this new metric. +/// \param label_count The number of labels. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_MetricNew( + TRITONSERVER_Metric** metric, TRITONSERVER_MetricFamily* family, + const TRITONSERVER_Parameter** labels, const uint64_t label_count); + +/// Delete a metric object. +/// All TRITONSERVER_Metric* objects should be deleted BEFORE their +/// corresponding TRITONSERVER_MetricFamily* objects have been deleted. +/// If a family is deleted before its metrics, an error will be returned. +/// +/// \param metric The metric object to delete. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_MetricDelete( + TRITONSERVER_Metric* metric); + +/// Get the current value of a metric object. +/// Supports metrics of kind TRITONSERVER_METRIC_KIND_COUNTER +/// and TRITONSERVER_METRIC_KIND_GAUGE, and returns +/// TRITONSERVER_ERROR_UNSUPPORTED for unsupported TRITONSERVER_MetricKind. +/// +/// \param metric The metric object to query. +/// \param value Returns the current value of the metric object. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_MetricValue( + TRITONSERVER_Metric* metric, double* value); + +/// Increment the current value of metric by value. +/// Supports metrics of kind TRITONSERVER_METRIC_KIND_GAUGE for any value, +/// and TRITONSERVER_METRIC_KIND_COUNTER for non-negative values. Returns +/// TRITONSERVER_ERROR_UNSUPPORTED for unsupported TRITONSERVER_MetricKind +/// and TRITONSERVER_ERROR_INVALID_ARG for negative values on a +/// TRITONSERVER_METRIC_KIND_COUNTER metric. +/// +/// \param metric The metric object to update. +/// \param value The amount to increment the metric's value by. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_MetricIncrement( + TRITONSERVER_Metric* metric, double value); + +/// Set the current value of metric to value. +/// Supports metrics of kind TRITONSERVER_METRIC_KIND_GAUGE and returns +/// TRITONSERVER_ERROR_UNSUPPORTED for unsupported TRITONSERVER_MetricKind. +/// +/// \param metric The metric object to update. +/// \param value The amount to set metric's value to. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_MetricSet( + TRITONSERVER_Metric* metric, double value); + +/// Get the TRITONSERVER_MetricKind of metric and its corresponding family. +/// +/// \param metric The metric object to query. 
+/// \param kind Returns the TRITONSERVER_MetricKind of metric. +/// \return a TRITONSERVER_Error indicating success or failure. +TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONSERVER_GetMetricKind( + TRITONSERVER_Metric* metric, TRITONSERVER_MetricKind* kind); + +#ifdef __cplusplus +} +#endif diff --git a/3rdparty/core-r22.12/src/backend_config.cc b/3rdparty/core-r22.12/src/backend_config.cc new file mode 100644 index 0000000000000000000000000000000000000000..367475fb0fff2925d6fdedeb924bef8a3ded779f --- /dev/null +++ b/3rdparty/core-r22.12/src/backend_config.cc @@ -0,0 +1,225 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "backend_config.h" + +#include "status.h" +#include "triton/common/logging.h" +#include "triton/common/model_config.h" + +namespace triton { namespace core { + +namespace { + +Status +GetTFSpecializedBackendName( + const triton::common::BackendCmdlineConfigMap& config_map, + std::string* specialized_name) +{ + std::string tf_version_str = "2"; + const auto& itr = config_map.find("tensorflow"); + if (itr != config_map.end()) { + if (BackendConfiguration(itr->second, "version", &tf_version_str).IsOk()) { + if ((tf_version_str != "1") && (tf_version_str != "2")) { + return Status( + Status::Code::INVALID_ARG, + "unexpected TensorFlow library version '" + tf_version_str + + "', expects 1 or 2."); + } + } + } + + *specialized_name += tf_version_str; + + return Status::Success; +} +} // namespace + +Status +BackendConfiguration( + const triton::common::BackendCmdlineConfig& config, const std::string& key, + std::string* val) +{ + for (const auto& pr : config) { + if (pr.first == key) { + *val = pr.second; + return Status::Success; + } + } + + return Status( + Status::Code::INTERNAL, + std::string("unable to find common backend configuration for '") + key + + "'"); +} + +Status +BackendConfigurationParseStringToDouble(const std::string& str, double* val) +{ + try { + *val = std::stod(str); + } + catch (...) 
{ + return Status( + Status::Code::INTERNAL, + "unable to parse common backend configuration as double"); + } + + return Status::Success; +} + +Status +BackendConfigurationParseStringToBool(const std::string& str, bool* val) +{ + try { + std::string lowercase_str{str}; + std::transform( + lowercase_str.begin(), lowercase_str.end(), lowercase_str.begin(), + [](unsigned char c) { return std::tolower(c); }); + *val = (lowercase_str == "true"); + } + catch (...) { + return Status( + Status::Code::INTERNAL, + "unable to parse common backend configuration as bool"); + } + + return Status::Success; +} + +Status +BackendConfigurationGlobalBackendsDirectory( + const triton::common::BackendCmdlineConfigMap& config_map, std::string* dir) +{ + const auto& itr = config_map.find(std::string()); + if (itr == config_map.end()) { + return Status( + Status::Code::INTERNAL, + "unable to find global backends directory configuration"); + } + + RETURN_IF_ERROR(BackendConfiguration(itr->second, "backend-directory", dir)); + + return Status::Success; +} + +Status +BackendConfigurationMinComputeCapability( + const triton::common::BackendCmdlineConfigMap& config_map, double* mcc) +{ +#ifdef TRITON_ENABLE_GPU + *mcc = TRITON_MIN_COMPUTE_CAPABILITY; +#else + *mcc = 0; +#endif // TRITON_ENABLE_GPU + + const auto& itr = config_map.find(std::string()); + if (itr == config_map.end()) { + return Status( + Status::Code::INTERNAL, "unable to find common backend configuration"); + } + + std::string min_compute_capability_str; + RETURN_IF_ERROR(BackendConfiguration( + itr->second, "min-compute-capability", &min_compute_capability_str)); + RETURN_IF_ERROR( + BackendConfigurationParseStringToDouble(min_compute_capability_str, mcc)); + + return Status::Success; +} + +Status +BackendConfigurationAutoCompleteConfig( + const triton::common::BackendCmdlineConfigMap& config_map, bool* acc) +{ + const auto& itr = config_map.find(std::string()); + if (itr == config_map.end()) { + return Status( + Status::Code::INTERNAL, "unable to find auto-complete configuration"); + } + + std::string auto_complete_config_str; + RETURN_IF_ERROR(BackendConfiguration( + itr->second, "auto-complete-config", &auto_complete_config_str)); + RETURN_IF_ERROR( + BackendConfigurationParseStringToBool(auto_complete_config_str, acc)); + + return Status::Success; +} + +Status +BackendConfigurationSpecializeBackendName( + const triton::common::BackendCmdlineConfigMap& config_map, + const std::string& backend_name, std::string* specialized_name) +{ + *specialized_name = backend_name; + if (backend_name == "tensorflow") { + RETURN_IF_ERROR(GetTFSpecializedBackendName(config_map, specialized_name)); + } + + return Status::Success; +} + +Status +BackendConfigurationBackendLibraryName( + const std::string& backend_name, std::string* libname) +{ +#ifdef _WIN32 + *libname = "triton_" + backend_name + ".dll"; +#else + *libname = "libtriton_" + backend_name + ".so"; +#endif + + return Status::Success; +} + +Status +BackendConfigurationModelLoadGpuFraction( + const triton::common::BackendCmdlineConfigMap& config_map, + const int device_id, double* memory_limit) +{ + *memory_limit = 1.0; + const auto& itr = config_map.find(std::string()); + if (itr == config_map.end()) { + return Status( + Status::Code::INTERNAL, + "unable to find global backends directory configuration"); + } + + static std::string key_prefix = "model-load-gpu-limit-device-"; + std::string memory_limit_str; + auto status = BackendConfiguration( + itr->second, key_prefix + std::to_string(device_id), 
&memory_limit_str); + // Allow missing key, default to 1.0 (no limit) if the limit is not specified + if (status.IsOk()) { + RETURN_IF_ERROR(BackendConfigurationParseStringToDouble( + memory_limit_str, memory_limit)); + } + + return Status::Success; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/backend_config.h b/3rdparty/core-r22.12/src/backend_config.h new file mode 100644 index 0000000000000000000000000000000000000000..acd2a0c214030eb0b3b0708c69cc98b44344d135 --- /dev/null +++ b/3rdparty/core-r22.12/src/backend_config.h @@ -0,0 +1,77 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include "status.h" +#include "triton/common/model_config.h" + +namespace triton { namespace core { + +/// Get a key's string value from a backend configuration. +Status BackendConfiguration( + const triton::common::BackendCmdlineConfig& config, const std::string& key, + std::string* val); + +/// Convert a backend configuration string value into a double. +Status BackendConfigurationParseStringToDouble( + const std::string& str, double* val); + +/// Convert a backend configuration string value into a bool. +Status BackendConfigurationParseStringToBool(const std::string& str, bool* val); + +/// Get the global backends directory from the backend configuration. +Status BackendConfigurationGlobalBackendsDirectory( + const triton::common::BackendCmdlineConfigMap& config_map, + std::string* dir); + +/// Get the minimum compute capability from the backend configuration. +Status BackendConfigurationMinComputeCapability( + const triton::common::BackendCmdlineConfigMap& config_map, double* mcc); + +/// Get the model configuration auto-complete setting from the backend +/// configuration. +Status BackendConfigurationAutoCompleteConfig( + const triton::common::BackendCmdlineConfigMap& config_map, bool* acc); + +/// Convert a backend name to the specialized version of that name +/// based on the backend configuration. 
For example, "tensorflow" will +/// convert to either "tensorflow1" or "tensorflow2" depending on how +/// tritonserver is run. +Status BackendConfigurationSpecializeBackendName( + const triton::common::BackendCmdlineConfigMap& config_map, + const std::string& backend_name, std::string* specialized_name); + +/// Return the shared library name for a backend. +Status BackendConfigurationBackendLibraryName( + const std::string& backend_name, std::string* libname); + +/// Get GPU memory limit fraction for model loading +/// from the backend configuration. +Status BackendConfigurationModelLoadGpuFraction( + const triton::common::BackendCmdlineConfigMap& config_map, + const int device_id, double* memory_limit); + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/backend_manager.cc b/3rdparty/core-r22.12/src/backend_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..265202b241703c7e58aabc0534c7d9ef5fdb2b1f --- /dev/null +++ b/3rdparty/core-r22.12/src/backend_manager.cc @@ -0,0 +1,383 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "backend_manager.h" + +#include "backend_memory_manager.h" +#include "server_message.h" +#include "shared_library.h" +#include "triton/common/logging.h" + +// For unknown reason, windows will not export the TRITONBACKEND_* +// functions declared with dllexport in tritonbackend.h. To get those +// functions exported it is (also?) necessary to mark the definitions +// in this file with dllexport as well. 
+#if defined(_MSC_VER) +#define TRITONAPI_DECLSPEC __declspec(dllexport) +#elif defined(__GNUC__) +#define TRITONAPI_DECLSPEC __attribute__((__visibility__("default"))) +#else +#define TRITONAPI_DECLSPEC +#endif + +namespace triton { namespace core { + +// +// TritonBackend +// +Status +TritonBackend::Create( + const std::string& name, const std::string& dir, const std::string& libpath, + const triton::common::BackendCmdlineConfig& backend_cmdline_config, + std::shared_ptr* backend) +{ + // Create the JSON representation of the backend configuration. + triton::common::TritonJson::Value backend_config_json( + triton::common::TritonJson::ValueType::OBJECT); + if (!backend_cmdline_config.empty()) { + triton::common::TritonJson::Value cmdline_json( + backend_config_json, triton::common::TritonJson::ValueType::OBJECT); + for (const auto& pr : backend_cmdline_config) { + RETURN_IF_ERROR(cmdline_json.AddString(pr.first.c_str(), pr.second)); + } + + RETURN_IF_ERROR( + backend_config_json.Add("cmdline", std::move(cmdline_json))); + } + + TritonServerMessage backend_config(backend_config_json); + + auto local_backend = std::shared_ptr( + new TritonBackend(name, dir, libpath, backend_config)); + + // Load the library and initialize all the entrypoints + RETURN_IF_ERROR(local_backend->LoadBackendLibrary()); + + // Backend initialization is optional... The TRITONBACKEND_Backend + // object is this TritonBackend object. We must set set shared + // library path to point to the backend directory in case the + // backend library attempts to load additional shared libaries. + if (local_backend->backend_init_fn_ != nullptr) { + std::unique_ptr slib; + RETURN_IF_ERROR(SharedLibrary::Acquire(&slib)); + RETURN_IF_ERROR(slib->SetLibraryDirectory(local_backend->dir_)); + + TRITONSERVER_Error* err = local_backend->backend_init_fn_( + reinterpret_cast(local_backend.get())); + + RETURN_IF_ERROR(slib->ResetLibraryDirectory()); + RETURN_IF_TRITONSERVER_ERROR(err); + } + + local_backend->UpdateAttributes(); + + *backend = std::move(local_backend); + return Status::Success; +} + +Status +TritonBackend::UpdateAttributes() +{ + if (backend_attri_fn_ == nullptr) { + return Status::Success; + } + + // Create an Attribute object for the backend to fill, note that it copies + // some fields from 'attributes_' while the others use default value. This + // is an ad hoc way to determine whether the attribute is set by the backend + // and keep / update current value. + Attribute latest; + latest.exec_policy_ = attributes_.exec_policy_; + RETURN_IF_TRITONSERVER_ERROR(backend_attri_fn_( + reinterpret_cast(this), + reinterpret_cast(&latest))); + + // Update attributes that were set + attributes_.exec_policy_ = latest.exec_policy_; + if (!latest.preferred_groups_.empty()) { + attributes_.preferred_groups_ = latest.preferred_groups_; + } + return Status::Success; +} + +TritonBackend::TritonBackend( + const std::string& name, const std::string& dir, const std::string& libpath, + const TritonServerMessage& backend_config) + : name_(name), dir_(dir), libpath_(libpath), + backend_config_(backend_config), state_(nullptr) +{ + ClearHandles(); +} + +TritonBackend::~TritonBackend() +{ + LOG_VERBOSE(1) << "unloading backend '" << name_ << "'"; + + // Backend finalization is optional... The TRITONBACKEND_Backend + // object is this TritonBackend object. 
+ if (backend_fini_fn_ != nullptr) { + LOG_TRITONSERVER_ERROR( + backend_fini_fn_(reinterpret_cast(this)), + "failed finalizing backend"); + } + + ClearHandles(); +} + +void +TritonBackend::ClearHandles() +{ + dlhandle_ = nullptr; + backend_init_fn_ = nullptr; + backend_fini_fn_ = nullptr; + backend_attri_fn_ = nullptr; + model_init_fn_ = nullptr; + model_fini_fn_ = nullptr; + inst_init_fn_ = nullptr; + inst_fini_fn_ = nullptr; + inst_exec_fn_ = nullptr; +} + +Status +TritonBackend::LoadBackendLibrary() +{ + TritonBackendInitFn_t bifn; + TritonBackendFiniFn_t bffn; + TritonBackendAttriFn_t bafn; + TritonModelInitFn_t mifn; + TritonModelFiniFn_t mffn; + TritonModelInstanceInitFn_t iifn; + TritonModelInstanceFiniFn_t iffn; + TritonModelInstanceExecFn_t iefn; + + { + std::unique_ptr slib; + RETURN_IF_ERROR(SharedLibrary::Acquire(&slib)); + + RETURN_IF_ERROR(slib->OpenLibraryHandle(libpath_, &dlhandle_)); + + // Backend initialize and finalize functions, optional + RETURN_IF_ERROR(slib->GetEntrypoint( + dlhandle_, "TRITONBACKEND_Initialize", true /* optional */, + reinterpret_cast(&bifn))); + RETURN_IF_ERROR(slib->GetEntrypoint( + dlhandle_, "TRITONBACKEND_Finalize", true /* optional */, + reinterpret_cast(&bffn))); + // Backend attribute function, optional + RETURN_IF_ERROR(slib->GetEntrypoint( + dlhandle_, "TRITONBACKEND_GetBackendAttribute", true /* optional */, + reinterpret_cast(&bafn))); + + // Model initialize and finalize functions, optional + RETURN_IF_ERROR(slib->GetEntrypoint( + dlhandle_, "TRITONBACKEND_ModelInitialize", true /* optional */, + reinterpret_cast(&mifn))); + RETURN_IF_ERROR(slib->GetEntrypoint( + dlhandle_, "TRITONBACKEND_ModelFinalize", true /* optional */, + reinterpret_cast(&mffn))); + + // Model instance initialize and finalize functions, optional + RETURN_IF_ERROR(slib->GetEntrypoint( + dlhandle_, "TRITONBACKEND_ModelInstanceInitialize", true /* optional */, + reinterpret_cast(&iifn))); + RETURN_IF_ERROR(slib->GetEntrypoint( + dlhandle_, "TRITONBACKEND_ModelInstanceFinalize", true /* optional */, + reinterpret_cast(&iffn))); + + // Model instance execute function, required + RETURN_IF_ERROR(slib->GetEntrypoint( + dlhandle_, "TRITONBACKEND_ModelInstanceExecute", false /* optional */, + reinterpret_cast(&iefn))); + } + + backend_init_fn_ = bifn; + backend_fini_fn_ = bffn; + backend_attri_fn_ = bafn; + model_init_fn_ = mifn; + model_fini_fn_ = mffn; + inst_init_fn_ = iifn; + inst_fini_fn_ = iffn; + inst_exec_fn_ = iefn; + + return Status::Success; +} + +extern "C" { + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ApiVersion(uint32_t* major, uint32_t* minor) +{ + *major = TRITONBACKEND_API_VERSION_MAJOR; + *minor = TRITONBACKEND_API_VERSION_MINOR; + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_BackendName(TRITONBACKEND_Backend* backend, const char** name) +{ + TritonBackend* tb = reinterpret_cast(backend); + *name = tb->Name().c_str(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_BackendConfig( + TRITONBACKEND_Backend* backend, TRITONSERVER_Message** backend_config) +{ + TritonBackend* tb = reinterpret_cast(backend); + *backend_config = const_cast( + reinterpret_cast(&tb->BackendConfig())); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_BackendExecutionPolicy( + TRITONBACKEND_Backend* backend, TRITONBACKEND_ExecutionPolicy* policy) +{ + TritonBackend* tb = reinterpret_cast(backend); + *policy = tb->ExecutionPolicy(); + return 
nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_BackendSetExecutionPolicy( + TRITONBACKEND_Backend* backend, TRITONBACKEND_ExecutionPolicy policy) +{ + TritonBackend* tb = reinterpret_cast(backend); + tb->SetExecutionPolicy(policy); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_BackendArtifacts( + TRITONBACKEND_Backend* backend, TRITONBACKEND_ArtifactType* artifact_type, + const char** location) +{ + TritonBackend* tb = reinterpret_cast(backend); + *artifact_type = TRITONBACKEND_ARTIFACT_FILESYSTEM; + *location = tb->Directory().c_str(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_BackendMemoryManager( + TRITONBACKEND_Backend* backend, TRITONBACKEND_MemoryManager** manager) +{ + static TritonMemoryManager gMemoryManager; + *manager = reinterpret_cast(&gMemoryManager); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_BackendState(TRITONBACKEND_Backend* backend, void** state) +{ + TritonBackend* tb = reinterpret_cast(backend); + *state = tb->State(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_BackendSetState(TRITONBACKEND_Backend* backend, void* state) +{ + TritonBackend* tb = reinterpret_cast(backend); + tb->SetState(state); + return nullptr; // success +} + +} // extern C + +// +// TritonBackendManager +// + +static std::weak_ptr backend_manager_; +static std::mutex mu_; + +Status +TritonBackendManager::Create(std::shared_ptr* manager) +{ + std::lock_guard lock(mu_); + + // If there is already a manager then we just use it... + *manager = backend_manager_.lock(); + if (*manager != nullptr) { + return Status::Success; + } + + manager->reset(new TritonBackendManager()); + backend_manager_ = *manager; + + return Status::Success; +} + +Status +TritonBackendManager::CreateBackend( + const std::string& name, const std::string& dir, const std::string& libpath, + const triton::common::BackendCmdlineConfig& backend_cmdline_config, + std::shared_ptr* backend) +{ + std::lock_guard lock(mu_); + + const auto& itr = backend_map_.find(libpath); + if (itr != backend_map_.end()) { + *backend = itr->second; + return Status::Success; + } + + RETURN_IF_ERROR(TritonBackend::Create( + name, dir, libpath, backend_cmdline_config, backend)); + backend_map_.insert({libpath, *backend}); + + return Status::Success; +} + +Status +TritonBackendManager::BackendState( + std::unique_ptr>>* + backend_state) +{ + std::lock_guard lock(mu_); + + std::unique_ptr>> + backend_state_map( + new std::unordered_map>); + for (const auto& backend_pair : backend_map_) { + auto& libpath = backend_pair.first; + auto backend = backend_pair.second; + + const char* backend_config; + size_t backend_config_size; + backend->BackendConfig().Serialize(&backend_config, &backend_config_size); + backend_state_map->insert( + {backend->Name(), std::vector{libpath, backend_config}}); + } + + *backend_state = std::move(backend_state_map); + + return Status::Success; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/backend_manager.h b/3rdparty/core-r22.12/src/backend_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..089268c4221ddc001f01c80f313c3b22f3338d03 --- /dev/null +++ b/3rdparty/core-r22.12/src/backend_manager.h @@ -0,0 +1,174 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include "constants.h" +#include "server_message.h" +#include "status.h" +#include "triton/common/model_config.h" +#include "tritonserver_apis.h" + +namespace triton { namespace core { + +// +// Proxy to a backend shared library. 
+// +class TritonBackend { + public: + struct Attribute { + Attribute() : exec_policy_(TRITONBACKEND_EXECUTION_BLOCKING) {} + TRITONBACKEND_ExecutionPolicy exec_policy_; + std::vector preferred_groups_; + }; + typedef TRITONSERVER_Error* (*TritonModelInitFn_t)( + TRITONBACKEND_Model* model); + typedef TRITONSERVER_Error* (*TritonModelFiniFn_t)( + TRITONBACKEND_Model* model); + typedef TRITONSERVER_Error* (*TritonModelInstanceInitFn_t)( + TRITONBACKEND_ModelInstance* instance); + typedef TRITONSERVER_Error* (*TritonModelInstanceFiniFn_t)( + TRITONBACKEND_ModelInstance* instance); + typedef TRITONSERVER_Error* (*TritonModelInstanceExecFn_t)( + TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, + const uint32_t request_cnt); + + static Status Create( + const std::string& name, const std::string& dir, + const std::string& libpath, + const triton::common::BackendCmdlineConfig& backend_cmdline_config, + std::shared_ptr* backend); + ~TritonBackend(); + + const std::string& Name() const { return name_; } + const std::string& Directory() const { return dir_; } + const TritonServerMessage& BackendConfig() const { return backend_config_; } + const Attribute& BackendAttributes() const { return attributes_; } + + TRITONBACKEND_ExecutionPolicy ExecutionPolicy() const + { + return attributes_.exec_policy_; + } + void SetExecutionPolicy(const TRITONBACKEND_ExecutionPolicy policy) + { + attributes_.exec_policy_ = policy; + } + + void* State() { return state_; } + void SetState(void* state) { state_ = state; } + + TritonModelInitFn_t ModelInitFn() const { return model_init_fn_; } + TritonModelFiniFn_t ModelFiniFn() const { return model_fini_fn_; } + TritonModelInstanceInitFn_t ModelInstanceInitFn() const + { + return inst_init_fn_; + } + TritonModelInstanceFiniFn_t ModelInstanceFiniFn() const + { + return inst_fini_fn_; + } + TritonModelInstanceExecFn_t ModelInstanceExecFn() const + { + return inst_exec_fn_; + } + + private: + typedef TRITONSERVER_Error* (*TritonBackendInitFn_t)( + TRITONBACKEND_Backend* backend); + typedef TRITONSERVER_Error* (*TritonBackendFiniFn_t)( + TRITONBACKEND_Backend* backend); + typedef TRITONSERVER_Error* (*TritonBackendAttriFn_t)( + TRITONBACKEND_Backend* backend, + TRITONBACKEND_BackendAttribute* backend_attributes); + + TritonBackend( + const std::string& name, const std::string& dir, + const std::string& libpath, const TritonServerMessage& backend_config); + + void ClearHandles(); + Status LoadBackendLibrary(); + + Status UpdateAttributes(); + + // The name of the backend. + const std::string name_; + + // Full path to the directory holding backend shared library and + // other artifacts. + const std::string dir_; + + // Full path to the backend shared library. + const std::string libpath_; + + // Backend configuration as JSON + TritonServerMessage backend_config_; + + // backend attributes + Attribute attributes_; + + // dlopen / dlsym handles + void* dlhandle_; + TritonBackendInitFn_t backend_init_fn_; + TritonBackendFiniFn_t backend_fini_fn_; + TritonBackendAttriFn_t backend_attri_fn_; + TritonModelInitFn_t model_init_fn_; + TritonModelFiniFn_t model_fini_fn_; + TritonModelInstanceInitFn_t inst_init_fn_; + TritonModelInstanceFiniFn_t inst_fini_fn_; + TritonModelInstanceExecFn_t inst_exec_fn_; + + // Opaque state associated with the backend. + void* state_; +}; + +// +// Manage communication with Triton backends and their lifecycle. 
+// +class TritonBackendManager { + public: + static Status Create(std::shared_ptr* manager); + + Status CreateBackend( + const std::string& name, const std::string& dir, + const std::string& libpath, + const triton::common::BackendCmdlineConfig& backend_cmdline_config, + std::shared_ptr* backend); + + Status BackendState( + std::unique_ptr< + std::unordered_map>>* + backend_state); + + private: + DISALLOW_COPY_AND_ASSIGN(TritonBackendManager); + TritonBackendManager() = default; + std::unordered_map> backend_map_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/backend_memory_manager.cc b/3rdparty/core-r22.12/src/backend_memory_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..0266d2169fe52969490b8d0cc42f4e99df4344d0 --- /dev/null +++ b/3rdparty/core-r22.12/src/backend_memory_manager.cc @@ -0,0 +1,149 @@ +// Copyright 2020-2022, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "backend_memory_manager.h" + +#include "pinned_memory_manager.h" +#include "status.h" +#include "tritonserver_apis.h" + +#ifdef TRITON_ENABLE_GPU +#include +#include "cuda_memory_manager.h" +#endif // TRITON_ENABLE_GPU + +// For unknown reason, windows will not export the TRITONBACKEND_* +// functions declared with dllexport in tritonbackend.h. To get those +// functions exported it is (also?) necessary to mark the definitions +// in this file with dllexport as well. 
+#if defined(_MSC_VER) +#define TRITONAPI_DECLSPEC __declspec(dllexport) +#elif defined(__GNUC__) +#define TRITONAPI_DECLSPEC __attribute__((__visibility__("default"))) +#else +#define TRITONAPI_DECLSPEC +#endif + +namespace triton { namespace core { + +extern "C" { + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_MemoryManagerAllocate( + TRITONBACKEND_MemoryManager* manager, void** buffer, + const TRITONSERVER_MemoryType memory_type, const int64_t memory_type_id, + const uint64_t byte_size) +{ + switch (memory_type) { + case TRITONSERVER_MEMORY_GPU: +#ifdef TRITON_ENABLE_GPU + { + auto status = CudaMemoryManager::Alloc(buffer, byte_size, memory_type_id); + if (!status.IsOk()) { + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(status.ErrorCode()), + status.Message().c_str()); + } + break; + } +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "GPU memory allocation not supported"); +#endif // TRITON_ENABLE_GPU + + case TRITONSERVER_MEMORY_CPU_PINNED: +#ifdef TRITON_ENABLE_GPU + { + TRITONSERVER_MemoryType mt = memory_type; + auto status = PinnedMemoryManager::Alloc(buffer, byte_size, &mt, false); + if (!status.IsOk()) { + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(status.ErrorCode()), + status.Message().c_str()); + } + break; + } +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "Pinned memory allocation not supported"); +#endif // TRITON_ENABLE_GPU + + case TRITONSERVER_MEMORY_CPU: { + *buffer = malloc(byte_size); + if (*buffer == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNAVAILABLE, "CPU memory allocation failed"); + } + break; + } + } + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_MemoryManagerFree( + TRITONBACKEND_MemoryManager* manager, void* buffer, + const TRITONSERVER_MemoryType memory_type, const int64_t memory_type_id) +{ + switch (memory_type) { + case TRITONSERVER_MEMORY_GPU: { +#ifdef TRITON_ENABLE_GPU + auto status = CudaMemoryManager::Free(buffer, memory_type_id); + if (!status.IsOk()) { + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(status.StatusCode()), + status.Message().c_str()); + } +#endif // TRITON_ENABLE_GPU + break; + } + + case TRITONSERVER_MEMORY_CPU_PINNED: { +#ifdef TRITON_ENABLE_GPU + auto status = PinnedMemoryManager::Free(buffer); + if (!status.IsOk()) { + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(status.StatusCode()), + status.Message().c_str()); + } +#endif // TRITON_ENABLE_GPU + break; + } + + case TRITONSERVER_MEMORY_CPU: + free(buffer); + break; + } + + return nullptr; // success +} + +} // extern C + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/backend_memory_manager.h b/3rdparty/core-r22.12/src/backend_memory_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..5364e13bc944bae22c8dde4705323387709578a2 --- /dev/null +++ b/3rdparty/core-r22.12/src/backend_memory_manager.h @@ -0,0 +1,36 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +namespace triton { namespace core { + +// Currently there is just a global memory manager that is used for +// all backends and which simply forwards requests on to the core +// memory manager. +struct TritonMemoryManager { +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/backend_model.cc b/3rdparty/core-r22.12/src/backend_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..2f838810589813c61969eade380fcaeff777a25a --- /dev/null +++ b/3rdparty/core-r22.12/src/backend_model.cc @@ -0,0 +1,1301 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
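// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this patch: how a backend typically uses
// the memory-manager entry points defined in backend_memory_manager.cc above.
// The backend first obtains the global manager handle with
// TRITONBACKEND_BackendMemoryManager() and then allocates and frees through
// it, so CPU, pinned and GPU requests all funnel into the core managers.
// 'AllocateScratch' is a hypothetical helper name used only for illustration;
// it assumes tritonbackend.h has been included.

static TRITONSERVER_Error*
AllocateScratch(
    TRITONBACKEND_Backend* backend, const uint64_t byte_size, void** buffer)
{
  TRITONBACKEND_MemoryManager* manager = nullptr;
  TRITONSERVER_Error* err =
      TRITONBACKEND_BackendMemoryManager(backend, &manager);
  if (err != nullptr) {
    return err;
  }

  // Plain CPU memory here; TRITONSERVER_MEMORY_GPU and
  // TRITONSERVER_MEMORY_CPU_PINNED route to the CUDA and pinned managers
  // when the core is built with TRITON_ENABLE_GPU.
  return TRITONBACKEND_MemoryManagerAllocate(
      manager, buffer, TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */,
      byte_size);  // nullptr on success
}

// The matching release uses the same memory type and id:
//   TRITONBACKEND_MemoryManagerFree(
//       manager, buffer, TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */);
// ---------------------------------------------------------------------------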
+ +#include "backend_model.h" + +#include +#include "backend_config.h" +#include "backend_model_instance.h" +#include "dynamic_batch_scheduler.h" +#include "filesystem.h" +#include "model_config_utils.h" +#include "numa_utils.h" +#include "sequence_batch_scheduler.h" +#include "sequence_state.h" +#include "server.h" +#include "server_message.h" +#include "shared_library.h" +#include "triton/common/logging.h" +#include "tritonserver_apis.h" + +// For unknown reason, windows will not export the TRITONBACKEND_* +// functions declared with dllexport in tritonbackend.h. To get those +// functions exported it is (also?) necessary to mark the definitions +// in this file with dllexport as well. +#if defined(_MSC_VER) +#define TRITONAPI_DECLSPEC __declspec(dllexport) +#elif defined(__GNUC__) +#define TRITONAPI_DECLSPEC __attribute__((__visibility__("default"))) +#else +#define TRITONAPI_DECLSPEC +#endif + +namespace triton { namespace core { + +Status +TritonModel::Create( + InferenceServer* server, const std::string& model_path, + const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map, + const triton::common::HostPolicyCmdlineConfigMap& host_policy_map, + const std::string& model_name, const int64_t version, + inference::ModelConfig model_config, const bool is_config_provided, + std::unique_ptr* model) +{ + model->reset(); + + // The model configuration must specify a backend. The name of the + // corresponding shared library must be libtriton_.so. + if (model_config.backend().empty()) { + return Status( + Status::Code::INVALID_ARG, + "must specify 'backend' for '" + model_config.name() + "'"); + } + + // Localize the content of the model repository corresponding to + // 'model_name'. This model holds a handle to the localized content + // so that it persists as long as the model is loaded. + std::shared_ptr localized_model_dir; + RETURN_IF_ERROR(LocalizePath(model_path, &localized_model_dir)); + + // Localize paths in backend model config + // [FIXME] Remove once a more permanent solution is implemented (DLIS-4211) + RETURN_IF_ERROR(LocalizePythonBackendExecutionEnvironmentPath( + model_path, &model_config, &localized_model_dir)); + + // Get some internal configuration values needed for initialization. + std::string backend_dir; + RETURN_IF_ERROR(BackendConfigurationGlobalBackendsDirectory( + backend_cmdline_config_map, &backend_dir)); + + bool auto_complete_config = false; + RETURN_IF_ERROR(BackendConfigurationAutoCompleteConfig( + backend_cmdline_config_map, &auto_complete_config)); + + double min_compute_capability = 0; + RETURN_IF_ERROR(BackendConfigurationMinComputeCapability( + backend_cmdline_config_map, &min_compute_capability)); + + std::string specialized_backend_name; + RETURN_IF_ERROR(BackendConfigurationSpecializeBackendName( + backend_cmdline_config_map, model_config.backend(), + &specialized_backend_name)); + + std::string backend_libname; + RETURN_IF_ERROR(BackendConfigurationBackendLibraryName( + specialized_backend_name, &backend_libname)); + + // Get the path to the backend shared library. Search path is + // version directory, model directory, global backend directory. 
+ const auto localized_model_path = localized_model_dir->Path(); + const auto version_path = + JoinPath({localized_model_path, std::to_string(version)}); + const std::string global_path = + JoinPath({backend_dir, specialized_backend_name}); + const std::vector search_paths = { + version_path, localized_model_path, global_path}; + + std::string backend_libdir; + std::string backend_libpath; + for (const auto& path : search_paths) { + const auto full_path = JoinPath({path, backend_libname}); + bool exists = false; + RETURN_IF_ERROR(FileExists(full_path, &exists)); + if (exists) { + backend_libdir = path; + backend_libpath = full_path; + break; + } + } + + if (backend_libpath.empty()) { + return Status( + Status::Code::INVALID_ARG, "unable to find '" + backend_libname + + "' for model '" + model_config.name() + + "', searched: " + version_path + ", " + + model_path + ", " + global_path); + } + + // Resolve the global backend configuration with the specific backend + // configuration + triton::common::BackendCmdlineConfig config; + RETURN_IF_ERROR(ResolveBackendConfigs( + backend_cmdline_config_map, model_config.backend(), config)); + + RETURN_IF_ERROR(SetBackendConfigDefaults(config)); + + std::shared_ptr backend; + RETURN_IF_ERROR(server->BackendManager()->CreateBackend( + model_config.backend(), backend_libdir, backend_libpath, config, + &backend)); + + // Normalize backend-dependent config + { + const auto& attributes = backend->BackendAttributes(); + // [WIP] formalize config normalization / validation + RETURN_IF_ERROR(NormalizeInstanceGroup( + min_compute_capability, attributes.preferred_groups_, &model_config)); + RETURN_IF_ERROR( + ValidateInstanceGroup(model_config, min_compute_capability)); + } + + // Create and initialize the model. + std::unique_ptr local_model(new TritonModel( + server, localized_model_dir, backend, min_compute_capability, version, + model_config, auto_complete_config)); + + TritonModel* raw_local_model = local_model.get(); + + // Model initialization is optional... The TRITONBACKEND_Model + // object is this TritonModel object. We must set set shared library + // path to point to the backend directory in case the backend + // library attempts to load additional shared libaries. + if (backend->ModelInitFn() != nullptr) { + std::unique_ptr slib; + RETURN_IF_ERROR(SharedLibrary::Acquire(&slib)); + RETURN_IF_ERROR(slib->SetLibraryDirectory(backend->Directory())); + + TRITONSERVER_Error* err = backend->ModelInitFn()( + reinterpret_cast(raw_local_model)); + + RETURN_IF_ERROR(slib->ResetLibraryDirectory()); + RETURN_IF_TRITONSERVER_ERROR(err); + } + + // Initialize the model for Triton core usage + RETURN_IF_ERROR(local_model->Init(is_config_provided)); + + bool device_blocking = false; + if (local_model->backend_->ExecutionPolicy() == + TRITONBACKEND_EXECUTION_DEVICE_BLOCKING) { + if (model_config.has_sequence_batching()) { + LOG_INFO << "Overriding execution policy to " + "\"TRITONBACKEND_EXECUTION_BLOCKING\" for sequence model \"" + << model_config.name() << "\""; + } else { + device_blocking = true; + } + } + + // Create and initialize the model instances for this model. 
+ RETURN_IF_ERROR(TritonModelInstance::CreateInstances( + raw_local_model, backend_cmdline_config_map, host_policy_map, + model_config, device_blocking)); + + RETURN_IF_ERROR(local_model->SetConfiguredScheduler()); + + *model = std::move(local_model); + return Status::Success; +} + +Status +TritonModel::ResolveBackendConfigs( + const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map, + const std::string& backend_name, + triton::common::BackendCmdlineConfig& config) +{ + const auto& global_itr = backend_cmdline_config_map.find(std::string()); + const auto& specific_itr = backend_cmdline_config_map.find(backend_name); + if (specific_itr == backend_cmdline_config_map.end() && + global_itr != backend_cmdline_config_map.end()) { + for (auto setting : global_itr->second) { + config.push_back(setting); + } + } else if ( + specific_itr != backend_cmdline_config_map.end() && + global_itr == backend_cmdline_config_map.end()) { + for (auto setting : specific_itr->second) { + config.push_back(setting); + } + } else if ( + specific_itr != backend_cmdline_config_map.end() && + global_itr != backend_cmdline_config_map.end()) { + triton::common::BackendCmdlineConfig global_backend_config = + global_itr->second; + triton::common::BackendCmdlineConfig specific_backend_config = + specific_itr->second; + + std::sort(global_backend_config.begin(), global_backend_config.end()); + std::sort(specific_backend_config.begin(), specific_backend_config.end()); + + size_t global_index = 0; + size_t specific_index = 0; + while (global_index < global_backend_config.size() && + specific_index < specific_backend_config.size()) { + auto& current_global_setting = global_backend_config.at(global_index); + auto& current_specific_setting = + specific_backend_config.at(specific_index); + if (current_specific_setting.first.compare( + current_global_setting.first) == 0) { + // specific setting overrides global setting + config.push_back(current_specific_setting); + ++global_index; + ++specific_index; + } else if ( + current_specific_setting.first.compare(current_global_setting.first) < + 0) { + config.push_back(current_specific_setting); + ++specific_index; + } else { + config.push_back(current_global_setting); + ++global_index; + } + } + + // add the rest of the global configs + if (global_index < global_backend_config.size()) { + auto& current_global_setting = global_backend_config.at(global_index); + config.push_back(current_global_setting); + } + + // add the rest of the specific settings + if (specific_index < specific_backend_config.size()) { + auto& current_specific_setting = + specific_backend_config.at(specific_index); + config.push_back(current_specific_setting); + } + } // else empty config + + return Status::Success; +} + + +const std::unordered_map backend_config_defaults( + {{"default-max-batch-size", "4"}}); + +Status +TritonModel::SetBackendConfigDefaults( + triton::common::BackendCmdlineConfig& config) +{ + auto backend_config_defaults_copy = backend_config_defaults; + + for (auto& setting : config) { + if (setting.first.compare("default-max-batch-size") == 0) { + LOG_VERBOSE(1) << "Found overwritten default setting: " << setting.first + << "," << setting.second; + backend_config_defaults_copy.erase(setting.first); + } + + if (backend_config_defaults_copy.empty()) { + break; + } + } + + // Anything left should be added to the config + for (const auto& default_setting : backend_config_defaults_copy) { + LOG_VERBOSE(1) << "Adding default backend config setting: " + << default_setting.first << "," 
<< default_setting.second; + config.push_back( + std::make_pair(default_setting.first, default_setting.second)); + } + + return Status::Success; +} + +Status +TritonModel::AddInstance( + std::unique_ptr&& instance, const bool passive) +{ + if (passive) { + passive_instances_.emplace_back(std::move(instance)); + } else { + instances_.emplace_back(std::move(instance)); + } + + return Status::Success; +} + +Status +TritonModel::UpdateModelConfig( + const uint32_t config_version, TRITONSERVER_Message* updated_config_message) +{ + const char* buffer; + size_t byte_size; + RETURN_IF_TRITONSERVER_ERROR(TRITONSERVER_MessageSerializeToJson( + updated_config_message, &buffer, &byte_size)); + inference::ModelConfig updated_config; + RETURN_IF_ERROR( + JsonToModelConfig({buffer, byte_size}, config_version, &updated_config)); + auto config = Config(); + config.set_max_batch_size(updated_config.max_batch_size()); + + auto inputs_config = config.mutable_input(); + *inputs_config = updated_config.input(); + auto outputs_config = config.mutable_output(); + *outputs_config = updated_config.output(); + + if (!config.scheduling_choice_case()) { + if (updated_config.has_dynamic_batching()) { + auto dynamic_batching_config = config.mutable_dynamic_batching(); + *dynamic_batching_config = updated_config.dynamic_batching(); + } else if (updated_config.has_sequence_batching()) { + auto sequence_batching_config = config.mutable_sequence_batching(); + *sequence_batching_config = updated_config.sequence_batching(); + } else if (updated_config.has_ensemble_scheduling()) { + auto ensemble_scheduling_config = config.mutable_ensemble_scheduling(); + *ensemble_scheduling_config = updated_config.ensemble_scheduling(); + } // else do nothing + } else if ( + config.scheduling_choice_case() != + updated_config.scheduling_choice_case()) { + return Status( + triton::common::Error::Code::INTERNAL, + (std::string("Cannot update scheduling choice from ") + + std::to_string(config.scheduling_choice_case()) + std::string(" to ") + + std::to_string(config.scheduling_choice_case()) + + std::string(" when auto-completing.")) + .c_str()); + } // else do nothing + + // Need to normalize the model configuration for + // populating missing fields. + RETURN_IF_ERROR(NormalizeModelConfig(min_compute_capability_, &config)); + + RETURN_IF_ERROR(SetModelConfig(config)); + + return Status::Success; +} + +Status +TritonModel::SetConfiguredScheduler() +{ + std::unique_ptr scheduler; + + // Need to enforce equal shape batches (i.e. non-ragged batches) if + // the model 1) allows one or more variable-size input tensors that + // are not marked as 'allow_ragged_batch' or 2) has one or more + // shape-tensor inputs. This is not needed if all input shapes are + // non-variable and if there are no shape tensors... so we don't + // enable it in that case for efficiency reasons. + std::unordered_map enforce_equal_shape_tensors; + for (const auto input : config_.input()) { + if (input.is_shape_tensor()) { + enforce_equal_shape_tensors.insert({input.name(), true}); + } else if ( + !input.allow_ragged_batch() && + (triton::common::GetElementCount(input) == -1)) { + enforce_equal_shape_tensors.insert({input.name(), false}); + } + } + + // If 'sequence_batching' is configured, then use the SequenceBatchScheduler, + // otherwise use the default DynamicBatchScheduler. 
+ if (config_.has_sequence_batching()) { + // Sequence batcher + RETURN_IF_ERROR(SequenceBatchScheduler::Create( + this, enforce_equal_shape_tensors, &scheduler)); + } else if (config_.has_dynamic_batching()) { + // Dynamic batcher + RETURN_IF_ERROR(DynamicBatchScheduler::Create( + this, nullptr, 0 /*nice*/, true /* dynamic_batching_enabled */, + config_.max_batch_size(), enforce_equal_shape_tensors, + config_.dynamic_batching(), + config_.response_cache().enable() /* response_cache_enable */, + &scheduler)); + } else { + // Default scheduler. Use dynamic batch scheduler (with batching + // disabled) as the default scheduler. + RETURN_IF_ERROR(DynamicBatchScheduler::Create( + this, nullptr, 0 /*nice*/, false /* dynamic_batching_enabled */, + 1 /* max_batch_size */, + std::unordered_map< + std::string, bool>() /* enforce_equal_shape_tensors */, + false /* preserve_ordering */, + config_.response_cache().enable() /* response_cache_enable */, + std::set() /* preferred_batch_sizes */, + 0 /* max_queue_delay_microseconds */, &scheduler)); + } + + return SetScheduler(std::move(scheduler)); +} + +Status +TritonModel::Initialize() +{ + for (const auto& instance : instances_) { + RETURN_IF_ERROR(instance->Initialize()); + } + + return Status::Success; +} + +Status +TritonModel::WarmUp() +{ + for (const auto& instance : instances_) { + RETURN_IF_ERROR(instance->WarmUp()); + } + + return Status::Success; +} + +TritonModel::TritonModel( + InferenceServer* server, + const std::shared_ptr& localized_model_dir, + const std::shared_ptr& backend, + const double min_compute_capability, const int64_t version, + const inference::ModelConfig& config, const bool auto_complete_config) + : Model( + min_compute_capability, localized_model_dir->Path(), version, config), + server_(server), min_compute_capability_(min_compute_capability), + auto_complete_config_(auto_complete_config), + localized_model_dir_(localized_model_dir), backend_(backend), + state_(nullptr) +{ +} + +TritonModel::~TritonModel() +{ + // Explicitly delete/finalize all model instances before finalizing + // the model itself. + instances_.clear(); + passive_instances_.clear(); + + // Unregister itself from the rate limiter. Note this should happen + // after all instances are destructed. Destrucing instances ensures + // there are no instance threads waiting on rate limiter for + // receiving their payloads. + server_->GetRateLimiter()->UnregisterModel(this); + + // Model finalization is optional... The TRITONBACKEND_Model + // object is this TritonModel object. 
+ if (backend_->ModelFiniFn() != nullptr) { + LOG_TRITONSERVER_ERROR( + backend_->ModelFiniFn()(reinterpret_cast(this)), + "failed finalizing model"); + } +} + +extern "C" { + +// +// TRITONBACKEND_Model +// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelName(TRITONBACKEND_Model* model, const char** name) +{ + TritonModel* tm = reinterpret_cast(model); + *name = tm->Name().c_str(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelVersion(TRITONBACKEND_Model* model, uint64_t* version) +{ + TritonModel* tm = reinterpret_cast(model); + *version = tm->Version(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelRepository( + TRITONBACKEND_Model* model, TRITONBACKEND_ArtifactType* artifact_type, + const char** location) +{ + TritonModel* tm = reinterpret_cast(model); + *artifact_type = TRITONBACKEND_ARTIFACT_FILESYSTEM; + *location = tm->LocalizedModelPath().c_str(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelConfig( + TRITONBACKEND_Model* model, const uint32_t config_version, + TRITONSERVER_Message** model_config) +{ + TritonModel* tm = reinterpret_cast(model); + + std::string model_config_json; + Status status = + ModelConfigToJson(tm->Config(), config_version, &model_config_json); + if (!status.IsOk()) { + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str()); + } + + *model_config = reinterpret_cast( + new TritonServerMessage(std::move(model_config_json))); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelAutoCompleteConfig( + TRITONBACKEND_Model* model, bool* auto_complete_config) +{ + TritonModel* tm = reinterpret_cast(model); + *auto_complete_config = tm->AutoCompleteConfig(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelSetConfig( + TRITONBACKEND_Model* model, const uint32_t config_version, + TRITONSERVER_Message* model_config) +{ + TritonModel* tm = reinterpret_cast(model); + Status status = tm->UpdateModelConfig(config_version, model_config); + if (!status.IsOk()) { + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str()); + } + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelServer( + TRITONBACKEND_Model* model, TRITONSERVER_Server** server) +{ + TritonModel* tm = reinterpret_cast(model); + *server = reinterpret_cast(tm->Server()); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelBackend( + TRITONBACKEND_Model* model, TRITONBACKEND_Backend** backend) +{ + TritonModel* tm = reinterpret_cast(model); + *backend = reinterpret_cast(tm->Backend().get()); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelState(TRITONBACKEND_Model* model, void** state) +{ + TritonModel* tm = reinterpret_cast(model); + *state = tm->State(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelSetState(TRITONBACKEND_Model* model, void* state) +{ + TritonModel* tm = reinterpret_cast(model); + tm->SetState(state); + return nullptr; // success +} + +/// +/// TRITONBACKEND_Request +/// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestId(TRITONBACKEND_Request* request, const char** id) +{ + InferenceRequest* tr = reinterpret_cast(request); + *id = tr->Id().c_str(); + return nullptr; // 
success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestCorrelationId(TRITONBACKEND_Request* request, uint64_t* id) +{ + InferenceRequest* tr = reinterpret_cast(request); + const InferenceRequest::SequenceId& correlation_id = tr->CorrelationId(); + if (correlation_id.Type() != InferenceRequest::SequenceId::DataType::UINT64) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (tr->LogRequest() + "correlation ID in request is not an unsigned int") + .c_str()); + } + *id = correlation_id.UnsignedIntValue(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestFlags(TRITONBACKEND_Request* request, uint32_t* flags) +{ + InferenceRequest* tr = reinterpret_cast(request); + *flags = tr->Flags(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestCorrelationIdString( + TRITONBACKEND_Request* request, const char** id) +{ + InferenceRequest* tr = reinterpret_cast(request); + const InferenceRequest::SequenceId& correlation_id = tr->CorrelationId(); + if (correlation_id.Type() != InferenceRequest::SequenceId::DataType::STRING) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (tr->LogRequest() + "correlation ID in request is not a string") + .c_str()); + } + *id = correlation_id.StringValue().c_str(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestInputCount(TRITONBACKEND_Request* request, uint32_t* count) +{ + InferenceRequest* tr = reinterpret_cast(request); + *count = tr->ImmutableInputs().size(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestInputName( + TRITONBACKEND_Request* request, const uint32_t index, + const char** input_name) +{ + *input_name = nullptr; + + InferenceRequest* tr = reinterpret_cast(request); + const auto& inputs = tr->ImmutableInputs(); + if (index >= inputs.size()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (tr->LogRequest() + "out of bounds index " + std::to_string(index) + + ": request has " + std::to_string(inputs.size()) + " inputs") + .c_str()); + } + + // The request inputs are not allowed to change once the request + // makes it to the backend, so it is ok to just iterate through the + // map. This linear search is the best we can do given the need for + // the inputs to be in a map and given the typical small number of + // inputs is better than having every request maintain the inputs as + // both map and vector. 
+ uint32_t cnt = 0; + for (const auto& pr : inputs) { + if (cnt++ == index) { + InferenceRequest::Input* in = pr.second; + *input_name = in->Name().c_str(); + break; + } + } + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestInput( + TRITONBACKEND_Request* request, const char* name, + TRITONBACKEND_Input** input) +{ + InferenceRequest* tr = reinterpret_cast(request); + const auto& inputs = tr->ImmutableInputs(); + const auto& itr = inputs.find(name); + if (itr == inputs.end()) { + *input = nullptr; + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (tr->LogRequest() + "unknown request input name " + name).c_str()); + } + + InferenceRequest::Input* in = itr->second; + *input = reinterpret_cast(in); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestInputByIndex( + TRITONBACKEND_Request* request, const uint32_t index, + TRITONBACKEND_Input** input) +{ + InferenceRequest* tr = reinterpret_cast(request); + const auto& inputs = tr->ImmutableInputs(); + if (index >= inputs.size()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (tr->LogRequest() + "out of bounds index " + std::to_string(index) + + ": request has " + std::to_string(inputs.size()) + " inputs") + .c_str()); + } + + // The request inputs are not allowed to change once the request + // makes it to the backend, so it is ok to just iterate through the + // map. This linear search is the best we can do given the need for + // the inputs to be in a map and given the typical small number of + // inputs is better than having every request maintain the inputs as + // both map and vector. + uint32_t cnt = 0; + for (const auto& pr : inputs) { + if (cnt++ == index) { + InferenceRequest::Input* in = pr.second; + *input = reinterpret_cast(in); + break; + } + } + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestOutputCount( + TRITONBACKEND_Request* request, uint32_t* count) +{ + InferenceRequest* tr = reinterpret_cast(request); + *count = tr->ImmutableRequestedOutputs().size(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestOutputName( + TRITONBACKEND_Request* request, const uint32_t index, + const char** output_name) +{ + *output_name = nullptr; + + InferenceRequest* tr = reinterpret_cast(request); + const auto& routputs = tr->ImmutableRequestedOutputs(); + if (index >= routputs.size()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (tr->LogRequest() + "out of bounds index " + std::to_string(index) + + ": request has " + std::to_string(routputs.size()) + + " requested outputs") + .c_str()); + } + + // The requested outputs are not allowed to change once the request + // makes it to the backend, so it is ok to just iterate through the + // set. This linear search is the best we can do given the requested + // outputs being in a set and given the typical small number of + // requested outputs it should not be a performance issue. 
+ uint32_t cnt = 0; + for (const auto& rout : routputs) { + if (cnt++ == index) { + *output_name = rout.c_str(); + break; + } + } + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestOutputBufferProperties( + TRITONBACKEND_Request* request, const char* name, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) +{ + InferenceRequest* tr = reinterpret_cast(request); + auto status = + tr->OutputBufferProperties(name, byte_size, memory_type, memory_type_id); + if (!status.IsOk()) { + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str()); + } + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_RequestRelease( + TRITONBACKEND_Request* request, uint32_t release_flags) +{ + InferenceRequest* tr = reinterpret_cast(request); + std::unique_ptr ur(tr); + InferenceRequest::Release(std::move(ur), release_flags); + return nullptr; // success +} + +/// +/// TRITONBACKEND_State +/// + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_StateUpdate(TRITONBACKEND_State* state) +{ + SequenceState* ts = reinterpret_cast(state); + auto status = ts->Update(); + + if (!status.IsOk()) { + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str()); + } + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_StateNew( + TRITONBACKEND_State** state, TRITONBACKEND_Request* request, + const char* name, const TRITONSERVER_DataType datatype, + const int64_t* shape, const uint32_t dims_count) +{ + InferenceRequest* tr = reinterpret_cast(request); + SequenceState* lstate; + std::vector lshape(shape, shape + dims_count); + auto& sequence_state = tr->GetSequenceStates(); + + if (sequence_state == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("unable to add state '") + name + + "'. State configuration is missing for model '" + tr->ModelName() + + "'.") + .c_str()); + } + + Status status = sequence_state->OutputState( + name, TritonToDataType(datatype), lshape, &lstate); + if (!status.IsOk()) { + *state = nullptr; + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str()); + } + + *state = reinterpret_cast(lstate); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_StateBuffer( + TRITONBACKEND_State* state, void** buffer, const uint64_t buffer_byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) +{ + SequenceState* to = reinterpret_cast(state); + Status status = Status::Success; + + // If the buffer size exactly matches the buffer available, reuse the + // currently allocated buffer. + if (to->Data()->TotalByteSize() == buffer_byte_size) { + const std::shared_ptr& memory = + reinterpret_cast&>(to->Data()); + + TRITONSERVER_MemoryType current_memory_type; + int64_t current_memory_type_id; + void* lbuffer = + memory->MutableBuffer(¤t_memory_type, ¤t_memory_type_id); + + // If the requested memory type doesn't match the current buffer, allocate a + // new buffer with the requested memory type and memory type id. 
+    if (current_memory_type == *memory_type &&
+        current_memory_type_id == *memory_type_id) {
+      *buffer = lbuffer;
+    } else {
+      std::shared_ptr<AllocatedMemory> memory =
+          std::make_shared<AllocatedMemory>(
+              buffer_byte_size, *memory_type, *memory_type_id);
+      *buffer = memory->MutableBuffer(memory_type, memory_type_id);
+      to->RemoveAllData();
+      status = to->SetData(memory);
+    }
+  } else {
+    std::shared_ptr<AllocatedMemory> memory =
+        std::make_shared<AllocatedMemory>(
+            buffer_byte_size, *memory_type, *memory_type_id);
+    *buffer = memory->MutableBuffer(memory_type, memory_type_id);
+    to->RemoveAllData();
+    status = to->SetData(memory);
+  }
+
+  if (!status.IsOk()) {
+    *buffer = nullptr;
+    return TRITONSERVER_ErrorNew(
+        StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str());
+  }
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_StateBufferAttributes(
+    TRITONBACKEND_State* state,
+    TRITONSERVER_BufferAttributes** buffer_attributes)
+{
+  SequenceState* to = reinterpret_cast<SequenceState*>(state);
+  to->Data()->BufferAt(
+      0, reinterpret_cast<BufferAttributes**>(buffer_attributes));
+
+  return nullptr;  // success
+}
+
+//
+// TRITONBACKEND_ResponseFactory
+//
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseFactoryNew(
+    TRITONBACKEND_ResponseFactory** factory, TRITONBACKEND_Request* request)
+{
+  InferenceRequest* tr = reinterpret_cast<InferenceRequest*>(request);
+  std::shared_ptr<InferenceResponseFactory>* response_factory =
+      new std::shared_ptr<InferenceResponseFactory>(tr->ResponseFactory());
+
+  *factory =
+      reinterpret_cast<TRITONBACKEND_ResponseFactory*>(response_factory);
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseFactoryDelete(TRITONBACKEND_ResponseFactory* factory)
+{
+  std::shared_ptr<InferenceResponseFactory>* response_factory =
+      reinterpret_cast<std::shared_ptr<InferenceResponseFactory>*>(factory);
+  delete response_factory;
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseFactorySendFlags(
+    TRITONBACKEND_ResponseFactory* factory, const uint32_t send_flags)
+{
+  std::shared_ptr<InferenceResponseFactory>* response_factory =
+      reinterpret_cast<std::shared_ptr<InferenceResponseFactory>*>(factory);
+  Status status = (*response_factory)->SendFlags(send_flags);
+  if (!status.IsOk()) {
+    return TRITONSERVER_ErrorNew(
+        StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str());
+  }
+  return nullptr;  // success
+}
+
+///
+/// TRITONBACKEND_Response
+///
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseNew(
+    TRITONBACKEND_Response** response, TRITONBACKEND_Request* request)
+{
+  InferenceRequest* tr = reinterpret_cast<InferenceRequest*>(request);
+
+  std::unique_ptr<InferenceResponse> tresp;
+  Status status = tr->ResponseFactory()->CreateResponse(&tresp);
+  if (!status.IsOk()) {
+    *response = nullptr;
+    return TRITONSERVER_ErrorNew(
+        StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str());
+  }
+
+  *response = reinterpret_cast<TRITONBACKEND_Response*>(tresp.release());
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseNewFromFactory(
+    TRITONBACKEND_Response** response, TRITONBACKEND_ResponseFactory* factory)
+{
+  std::shared_ptr<InferenceResponseFactory>* response_factory =
+      reinterpret_cast<std::shared_ptr<InferenceResponseFactory>*>(factory);
+
+  std::unique_ptr<InferenceResponse> tr;
+  Status status = (*response_factory)->CreateResponse(&tr);
+  if (!status.IsOk()) {
+    *response = nullptr;
+    return TRITONSERVER_ErrorNew(
+        StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str());
+  }
+
+  *response = reinterpret_cast<TRITONBACKEND_Response*>(tr.release());
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseDelete(TRITONBACKEND_Response* response)
+{
+  InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
+  delete tr;
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseSetStringParameter(
+    TRITONBACKEND_Response* response, const char* name, const char* value)
+{
+  InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
+  Status status = tr->AddParameter(name, value);
+  if (!status.IsOk()) {
+    return TRITONSERVER_ErrorNew(
+        StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str());
+  }
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseSetIntParameter(
+    TRITONBACKEND_Response* response, const char* name, const int64_t value)
+{
+  InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
+  Status status = tr->AddParameter(name, value);
+  if (!status.IsOk()) {
+    return TRITONSERVER_ErrorNew(
+        StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str());
+  }
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseSetBoolParameter(
+    TRITONBACKEND_Response* response, const char* name, const bool value)
+{
+  InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
+  Status status = tr->AddParameter(name, value);
+  if (!status.IsOk()) {
+    return TRITONSERVER_ErrorNew(
+        StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str());
+  }
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseOutput(
+    TRITONBACKEND_Response* response, TRITONBACKEND_Output** output,
+    const char* name, const TRITONSERVER_DataType datatype,
+    const int64_t* shape, const uint32_t dims_count)
+{
+  InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
+  std::vector<int64_t> lshape(shape, shape + dims_count);
+  InferenceResponse::Output* loutput;
+  Status status = tr->AddOutput(
+      name, TritonToDataType(datatype), std::move(lshape), &loutput);
+  if (!status.IsOk()) {
+    *output = nullptr;
+    return TRITONSERVER_ErrorNew(
+        StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str());
+  }
+
+  *output = reinterpret_cast<TRITONBACKEND_Output*>(loutput);
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_ResponseSend(
+    TRITONBACKEND_Response* response, const uint32_t send_flags,
+    TRITONSERVER_Error* error)
+{
+  InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
+
+  Status status;
+
+  std::unique_ptr<InferenceResponse> utr(tr);
+  if (error == nullptr) {
+    status = InferenceResponse::Send(std::move(utr), send_flags);
+  } else {
+    status = InferenceResponse::SendWithStatus(
+        std::move(utr), send_flags,
+        Status(
+            TritonCodeToStatusCode(TRITONSERVER_ErrorCode(error)),
+            TRITONSERVER_ErrorMessage(error)));
+  }
+
+  if (!status.IsOk()) {
+    return TRITONSERVER_ErrorNew(
+        StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str());
+  }
+
+  return nullptr;  // success
+}
+
+///
+/// TRITONBACKEND_Input
+///
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_InputProperties(
+    TRITONBACKEND_Input* input, const char** name,
+    TRITONSERVER_DataType* datatype, const int64_t** shape,
+    uint32_t* dims_count, uint64_t* byte_size, uint32_t* buffer_count)
+{
+  InferenceRequest::Input* ti =
+      reinterpret_cast<InferenceRequest::Input*>(input);
+  if (name != nullptr) {
+    *name = ti->Name().c_str();
+  }
+  if (datatype != nullptr) {
+    *datatype = DataTypeToTriton(ti->DType());
+  }
+  if (shape != nullptr) {
+    *shape = ti->ShapeWithBatchDim().data();
+  }
+  if (dims_count != nullptr) {
+    *dims_count = ti->ShapeWithBatchDim().size();
+  }
+  if (byte_size != nullptr) {
+    *byte_size = ti->Data()->TotalByteSize();
+  }
+  if (buffer_count != nullptr) {
+    *buffer_count = ti->DataBufferCount();
+  }
+  return nullptr;  // success
+}
+
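+// Illustrative sketch, not part of the upstream source: a backend's
+// TRITONBACKEND_ModelInstanceExecute implementation would typically consume a
+// request's inputs through the accessors defined above. Error checking is
+// omitted and names such as 'idx', 'b' and 'buf' are placeholders.
+//
+//   uint32_t input_count = 0;
+//   TRITONBACKEND_RequestInputCount(request, &input_count);
+//   for (uint32_t idx = 0; idx < input_count; ++idx) {
+//     TRITONBACKEND_Input* input = nullptr;
+//     TRITONBACKEND_RequestInputByIndex(request, idx, &input);
+//
+//     const char* name = nullptr;
+//     TRITONSERVER_DataType dtype;
+//     const int64_t* shape = nullptr;
+//     uint32_t dims_count = 0;
+//     uint64_t byte_size = 0;
+//     uint32_t buffer_count = 0;
+//     TRITONBACKEND_InputProperties(
+//         input, &name, &dtype, &shape, &dims_count, &byte_size,
+//         &buffer_count);
+//
+//     for (uint32_t b = 0; b < buffer_count; ++b) {
+//       const void* buf = nullptr;
+//       uint64_t buf_byte_size = 0;
+//       TRITONSERVER_MemoryType mem_type = TRITONSERVER_MEMORY_CPU;
+//       int64_t mem_type_id = 0;
+//       TRITONBACKEND_InputBuffer(
+//           input, b, &buf, &buf_byte_size, &mem_type, &mem_type_id);
+//       // ... copy or interpret 'buf' as needed ...
+//     }
+//   }
+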
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_InputPropertiesForHostPolicy(
+    TRITONBACKEND_Input* input, const char* host_policy_name, const char** name,
+    TRITONSERVER_DataType* datatype, const int64_t** shape,
+    uint32_t* dims_count, uint64_t* byte_size, uint32_t* buffer_count)
+{
+  InferenceRequest::Input* ti =
+      reinterpret_cast<InferenceRequest::Input*>(input);
+  if (name != nullptr) {
+    *name = ti->Name().c_str();
+  }
+  if (datatype != nullptr) {
+    *datatype = DataTypeToTriton(ti->DType());
+  }
+  if (shape != nullptr) {
+    *shape = ti->ShapeWithBatchDim().data();
+  }
+  if (dims_count != nullptr) {
+    *dims_count = ti->ShapeWithBatchDim().size();
+  }
+  if (host_policy_name != nullptr) {
+    if (byte_size != nullptr) {
+      *byte_size = ti->Data(host_policy_name)->TotalByteSize();
+    }
+    if (buffer_count != nullptr) {
+      *buffer_count = ti->DataBufferCountForHostPolicy(host_policy_name);
+    }
+  } else {
+    if (byte_size != nullptr) {
+      *byte_size = ti->Data()->TotalByteSize();
+    }
+    if (buffer_count != nullptr) {
+      *buffer_count = ti->DataBufferCount();
+    }
+  }
+  return nullptr;  // success
+}
+
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_InputBuffer(
+    TRITONBACKEND_Input* input, const uint32_t index, const void** buffer,
+    uint64_t* buffer_byte_size, TRITONSERVER_MemoryType* memory_type,
+    int64_t* memory_type_id)
+{
+  InferenceRequest::Input* ti =
+      reinterpret_cast<InferenceRequest::Input*>(input);
+  Status status = ti->DataBuffer(
+      index, buffer, buffer_byte_size, memory_type, memory_type_id);
+  if (!status.IsOk()) {
+    *buffer = nullptr;
+    *buffer_byte_size = 0;
+    return TRITONSERVER_ErrorNew(
+        StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str());
+  }
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_InputBufferAttributes(
+    TRITONBACKEND_Input* input, const uint32_t index, const void** buffer,
+    TRITONSERVER_BufferAttributes** buffer_attributes)
+{
+  InferenceRequest::Input* ti =
+      reinterpret_cast<InferenceRequest::Input*>(input);
+  Status status = ti->DataBufferAttributes(
+      index, buffer, reinterpret_cast<BufferAttributes**>(buffer_attributes));
+  if (!status.IsOk()) {
+    *buffer = nullptr;
+    *buffer_attributes = nullptr;
+    return TRITONSERVER_ErrorNew(
+        StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str());
+  }
+  return nullptr;  // success
+}
+
+TRITONAPI_DECLSPEC TRITONSERVER_Error*
+TRITONBACKEND_InputBufferForHostPolicy(
+    TRITONBACKEND_Input* input, const char* host_policy_name,
+    const uint32_t index, const void** buffer, uint64_t* buffer_byte_size,
+    TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id)
+{
+  InferenceRequest::Input* ti =
+      reinterpret_cast<InferenceRequest::Input*>(input);
+
+  Status status =
+      (host_policy_name == nullptr)
+          ?
ti->DataBuffer( + index, buffer, buffer_byte_size, memory_type, memory_type_id) + : ti->DataBufferForHostPolicy( + index, buffer, buffer_byte_size, memory_type, memory_type_id, + host_policy_name); + if (!status.IsOk()) { + *buffer = nullptr; + *buffer_byte_size = 0; + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str()); + } + return nullptr; // success +} + +/// +/// TRITONBACKEND_Output +/// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_OutputBuffer( + TRITONBACKEND_Output* output, void** buffer, + const uint64_t buffer_byte_size, TRITONSERVER_MemoryType* memory_type, + int64_t* memory_type_id) +{ + InferenceResponse::Output* to = + reinterpret_cast(output); + Status status = to->AllocateDataBuffer( + buffer, buffer_byte_size, memory_type, memory_type_id); + if (!status.IsOk()) { + *buffer = nullptr; + return TRITONSERVER_ErrorNew( + StatusCodeToTritonCode(status.StatusCode()), status.Message().c_str()); + } + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_OutputBufferAttributes( + TRITONBACKEND_Output* output, + TRITONSERVER_BufferAttributes** buffer_attributes) +{ + InferenceResponse::Output* to = + reinterpret_cast(output); + + *buffer_attributes = reinterpret_cast( + to->GetBufferAttributes()); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_BackendAttributeAddPreferredInstanceGroup( + TRITONBACKEND_BackendAttribute* backend_attributes, + const TRITONSERVER_InstanceGroupKind kind, const uint64_t count, + const uint64_t* device_ids, const uint64_t id_count) +{ + auto ba = reinterpret_cast(backend_attributes); + ba->preferred_groups_.emplace_back(); + auto& pg = ba->preferred_groups_.back(); + switch (kind) { + case TRITONSERVER_INSTANCEGROUPKIND_AUTO: + pg.set_kind(inference::ModelInstanceGroup::KIND_AUTO); + break; + case TRITONSERVER_INSTANCEGROUPKIND_CPU: + pg.set_kind(inference::ModelInstanceGroup::KIND_CPU); + break; + case TRITONSERVER_INSTANCEGROUPKIND_GPU: + pg.set_kind(inference::ModelInstanceGroup::KIND_GPU); + break; + case TRITONSERVER_INSTANCEGROUPKIND_MODEL: + pg.set_kind(inference::ModelInstanceGroup::KIND_MODEL); + break; + } + pg.set_count(count); + if (device_ids != nullptr) { + for (size_t i = 0; i < id_count; ++i) { + pg.add_gpus(device_ids[i]); + } + } + return nullptr; +} + +} // extern C + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/backend_model.h b/3rdparty/core-r22.12/src/backend_model.h new file mode 100644 index 0000000000000000000000000000000000000000..4e3941eb278bb826837976fb2f34d415ea85b47b --- /dev/null +++ b/3rdparty/core-r22.12/src/backend_model.h @@ -0,0 +1,133 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include "backend_manager.h" +#include "filesystem.h" +#include "infer_request.h" +#include "model.h" +#include "model_config.pb.h" +#include "status.h" + +namespace triton { namespace core { + +class InferenceServer; +class TritonModelInstance; + +// +// Represents a model. +// +// Inheriting from Model to implement backend APIs +// +class TritonModel : public Model { + public: + static Status Create( + InferenceServer* server, const std::string& model_path, + const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map, + const triton::common::HostPolicyCmdlineConfigMap& host_policy_map, + const std::string& model_name, const int64_t version, + inference::ModelConfig model_config, const bool is_config_provided, + std::unique_ptr* model); + ~TritonModel(); + + const std::string& LocalizedModelPath() const + { + return localized_model_dir_->Path(); + } + InferenceServer* Server() { return server_; } + bool AutoCompleteConfig() const { return auto_complete_config_; } + Status UpdateModelConfig( + const uint32_t config_version, + TRITONSERVER_Message* updated_config_message); + const std::shared_ptr& Backend() const { return backend_; } + const std::vector>& Instances() const + { + return instances_; + } + void* State() { return state_; } + void SetState(void* state) { state_ = state; } + Status AddInstance( + std::unique_ptr&& instance, const bool passive); + + private: + DISALLOW_COPY_AND_ASSIGN(TritonModel); + + TritonModel( + InferenceServer* server, + const std::shared_ptr& localized_model_dir, + const std::shared_ptr& backend, + const double min_compute_capability, const int64_t version, + const inference::ModelConfig& config, const bool auto_complete_config); + + // Set the scheduler based on the model configuration. The scheduler + // can only be set once for a backend. + Status SetConfiguredScheduler(); + + // Merges the global backend configs with the specific + // backend configs. + static Status ResolveBackendConfigs( + const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map, + const std::string& backend_name, + triton::common::BackendCmdlineConfig& config); + + // Sets defaults for some backend configurations when none are specified on + // the command line. + static Status SetBackendConfigDefaults( + triton::common::BackendCmdlineConfig& config); + + Status Initialize(); + Status WarmUp(); + + // The server object that owns this model. The model holds this as a + // raw pointer because the lifetime of the server is guaranteed to + // be longer than the lifetime of a model owned by the server. + InferenceServer* server_; + + // The minimum supported compute capability on device. 
+ const double min_compute_capability_; + + // Whether the backend should attempt to auto-complete the model config. + const bool auto_complete_config_; + + // The localized repo directory holding the model. If localization + // required creation of a temporary local copy then that copy will + // persist as along as this object is retained by this model. + std::shared_ptr localized_model_dir_; + + // Backend used by this model. + std::shared_ptr backend_; + + // The model instances for this model. + std::vector> instances_; + std::vector> passive_instances_; + + // Opaque state associated with this model. + void* state_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/backend_model_instance.cc b/3rdparty/core-r22.12/src/backend_model_instance.cc new file mode 100644 index 0000000000000000000000000000000000000000..d91452eccdd4dcfedf3124fd06d87207e836fe06 --- /dev/null +++ b/3rdparty/core-r22.12/src/backend_model_instance.cc @@ -0,0 +1,966 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "backend_model_instance.h" + +#ifndef _WIN32 +#include +#include +#include +#endif +#include "backend_config.h" +#include "backend_model.h" +#include "cuda_utils.h" +#include "metrics.h" +#include "model_config.pb.h" +#include "numa_utils.h" +#include "server.h" +#include "shared_library.h" +#include "triton/common/logging.h" +#include "triton/common/nvtx.h" +#include "tritonserver_apis.h" + +// For unknown reason, windows will not export the TRITONBACKEND_* +// functions declared with dllexport in tritonbackend.h. To get those +// functions exported it is (also?) necessary to mark the definitions +// in this file with dllexport as well. 
+#if defined(_MSC_VER) +#define TRITONAPI_DECLSPEC __declspec(dllexport) +#elif defined(__GNUC__) +#define TRITONAPI_DECLSPEC __attribute__((__visibility__("default"))) +#else +#define TRITONAPI_DECLSPEC +#endif + +namespace triton { namespace core { + +namespace { +// Utilities for warmup feature +TRITONSERVER_Error* +WarmupResponseAlloc( + TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type, + int64_t preferred_memory_type_id, void* userp, void** buffer, + void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type, + int64_t* actual_memory_type_id) +{ + *buffer = malloc(byte_size); + if (*buffer != nullptr) { + *actual_memory_type = TRITONSERVER_MEMORY_CPU; + *actual_memory_type_id = 0; + return nullptr; + } + + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "failed to allocate output buffer for warmup."); +} + +TRITONSERVER_Error* +WarmupResponseRelease( + TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp, + size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) +{ + free(buffer); + return nullptr; +} + +ResponseAllocator warmup_allocator = ResponseAllocator( + WarmupResponseAlloc, WarmupResponseRelease, nullptr /* start_fn */); + +void +WarmupResponseComplete( + TRITONSERVER_InferenceResponse* iresponse, const uint32_t flags, + void* userp) +{ + auto res_pair = reinterpret_cast< + std::pair, std::vector*>*>(userp); + if (iresponse != nullptr) { + auto err = TRITONSERVER_InferenceResponseError(iresponse); + if (err != nullptr) { + // The error vector is shared by all requests in the batch for now + static std::mutex res_mtx; + { + std::lock_guard lk(res_mtx); + res_pair->second->emplace_back(TRITONSERVER_ErrorMessage(err)); + } + TRITONSERVER_ErrorDelete(err); + } + // Just delete the response, warmup doesn't check for correctness + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceResponseDelete(iresponse), + "deleting warmup response"); + } + // Last response + if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0) { + res_pair->first.set_value(); + } +} + +void +WarmupRequestComplete( + TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp) +{ + if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) { + // Don't need to release request here, it is managed in WarmupData + if (userp != nullptr) { + auto warmup_promise = reinterpret_cast*>(userp); + warmup_promise->set_value(); + } + } +} + +} // namespace + +TritonModelInstance::TritonModelInstance( + TritonModel* model, const std::string& name, const size_t index, + const TRITONSERVER_InstanceGroupKind kind, const int32_t device_id, + const std::vector& profile_names, const bool passive, + const triton::common::HostPolicyCmdlineConfig& host_policy, + const TritonServerMessage& host_policy_message, + const std::vector& secondary_devices) + : model_(model), name_(name), index_(index), kind_(kind), + device_id_(device_id), host_policy_(host_policy), + host_policy_message_(host_policy_message), profile_names_(profile_names), + passive_(passive), secondary_devices_(secondary_devices), state_(nullptr) +{ +#ifdef TRITON_ENABLE_METRICS + if (Metrics::Enabled()) { + // Use an ID in the metric only for GPU instances. Otherwise use + // METRIC_REPORTER_ID_CPU to indicate no device should be reported in the + // metric. + const int id = (kind_ == TRITONSERVER_INSTANCEGROUPKIND_GPU) + ? 
device_id_ + : METRIC_REPORTER_ID_CPU; + MetricModelReporter::Create( + model_->Name(), model_->Version(), id, model_->Config().metric_tags(), + &reporter_); + } +#endif // TRITON_ENABLE_METRICS +} + +TritonModelInstance::~TritonModelInstance() +{ + if (triton_backend_thread_.get() != nullptr) { + triton_backend_thread_->StopBackendThread(); + } + + // Model finalization is optional... + if (model_->Backend()->ModelInstanceFiniFn() != nullptr) { + LOG_TRITONSERVER_ERROR( + model_->Backend()->ModelInstanceFiniFn()( + reinterpret_cast(this)), + "failed finalizing model instance"); + } +} + +Status +TritonModelInstance::CreateInstances( + TritonModel* model, + const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map, + const triton::common::HostPolicyCmdlineConfigMap& host_policy_map, + const inference::ModelConfig& model_config, const bool device_blocking) +{ + static triton::common::HostPolicyCmdlineConfig empty_host_policy; + + // This structure is used to allocate TritonBackendThread to instances on same + // device for device blocking execution policy. + std::map> device_to_thread_map; + + for (const auto& group : model_config.instance_group()) { + std::vector profile_names; + for (const auto& profile_name : group.profile()) { + profile_names.push_back(profile_name); + } + std::vector secondary_devices; + for (const auto& secondary_device : group.secondary_devices()) { + secondary_devices.emplace_back( + inference:: + ModelInstanceGroup_SecondaryDevice_SecondaryDeviceKind_Name( + secondary_device.kind()), + secondary_device.device_id()); + } + for (int32_t c = 0; c < group.count(); ++c) { + std::string instance_name{group.count() > 1 + ? group.name() + "_" + std::to_string(c) + : group.name()}; + const bool passive = group.passive(); + std::vector> + instance_setting; + if (group.kind() == inference::ModelInstanceGroup::KIND_CPU) { + instance_setting.emplace_back( + group.host_policy().empty() ? "cpu" : group.host_policy(), + TRITONSERVER_INSTANCEGROUPKIND_CPU, 0 /* device_id */, + &group.rate_limiter()); + } else if (group.kind() == inference::ModelInstanceGroup::KIND_GPU) { + for (const int32_t device_id : group.gpus()) { + instance_setting.emplace_back( + group.host_policy().empty() ? ("gpu_" + std::to_string(device_id)) + : group.host_policy(), + TRITONSERVER_INSTANCEGROUPKIND_GPU, device_id, + &group.rate_limiter()); + } + } else if (group.kind() == inference::ModelInstanceGroup::KIND_MODEL) { + instance_setting.emplace_back( + group.host_policy().empty() ? 
"model" : group.host_policy(), + TRITONSERVER_INSTANCEGROUPKIND_MODEL, 0 /* device_id */, + &group.rate_limiter()); + } else { + return Status( + Status::Code::INVALID_ARG, + std::string("instance_group kind ") + + ModelInstanceGroup_Kind_Name(group.kind()) + " not supported"); + } + for (const auto is : instance_setting) { + const auto& kind = std::get<1>(is); + const auto& id = std::get<2>(is); + + const std::string& policy_name = std::get<0>(is); + const triton::common::HostPolicyCmdlineConfig* host_policy; + const auto policy_it = host_policy_map.find(policy_name); + if (policy_it != host_policy_map.end()) { + host_policy = &policy_it->second; + } else { + host_policy = &empty_host_policy; + } + RETURN_IF_ERROR(SetNumaConfigOnThread(*host_policy)); + auto err = CreateInstance( + model, instance_name, c, kind, id, profile_names, passive, + policy_name, *host_policy, *(std::get<3>(is)), device_blocking, + &device_to_thread_map, secondary_devices); + RETURN_IF_ERROR(ResetNumaMemoryPolicy()); + RETURN_IF_ERROR(err); + + // When deploying on GPU, we want to make sure the GPU memory usage + // is within allowed range, otherwise, stop the creation to ensure + // there is sufficient GPU memory for other use. + // We check the usage after loading the instance to better enforcing + // the limit. If we check before loading, we may create instance + // that occupies the rest of available memory which against the purpose + if (kind == TRITONSERVER_INSTANCEGROUPKIND_GPU) { + size_t free, total; + double memory_limit; + RETURN_IF_ERROR(GetDeviceMemoryInfo(id, &free, &total)); + RETURN_IF_ERROR(BackendConfigurationModelLoadGpuFraction( + backend_cmdline_config_map, id, &memory_limit)); + const size_t allow = total * memory_limit; + const size_t used = total - free; + if (used > allow) { + return Status( + Status::Code::UNAVAILABLE, + std::string("can not create model '") + instance_name + + "': memory limit set for " + + TRITONSERVER_InstanceGroupKindString(kind) + " " + + std::to_string(id) + + " has exceeded, model loading is rejected."); + } + } + } + } + } + + return Status::Success; +} + +Status +TritonModelInstance::CreateInstance( + TritonModel* model, const std::string& name, const size_t index, + const TRITONSERVER_InstanceGroupKind kind, const int32_t device_id, + const std::vector& profile_names, const bool passive, + const std::string& host_policy_name, + const triton::common::HostPolicyCmdlineConfig& host_policy, + const inference::ModelRateLimiter& rate_limiter_config, + const bool device_blocking, + std::map>* + device_to_thread_map, + const std::vector& secondary_devices) +{ + // Create the JSON representation of the backend configuration. + triton::common::TritonJson::Value host_policy_json( + triton::common::TritonJson::ValueType::OBJECT); + triton::common::TritonJson::Value policy_setting_json( + host_policy_json, triton::common::TritonJson::ValueType::OBJECT); + for (const auto& pr : host_policy) { + RETURN_IF_ERROR(policy_setting_json.AddString(pr.first.c_str(), pr.second)); + } + + RETURN_IF_ERROR(host_policy_json.Add( + host_policy_name.c_str(), std::move(policy_setting_json))); + TritonServerMessage host_policy_message(host_policy_json); + + std::unique_ptr local_instance(new TritonModelInstance( + model, name, index, kind, device_id, profile_names, passive, host_policy, + host_policy_message, secondary_devices)); + + TRITONBACKEND_ModelInstance* triton_instance = + reinterpret_cast(local_instance.get()); + + // Instance initialization is optional... 
We must set set shared + // library path to point to the backend directory in case the + // backend library attempts to load additional shared libaries. + if (model->Backend()->ModelInstanceInitFn() != nullptr) { + std::unique_ptr slib; + RETURN_IF_ERROR(SharedLibrary::Acquire(&slib)); + RETURN_IF_ERROR(slib->SetLibraryDirectory(model->Backend()->Directory())); + + TRITONSERVER_Error* err = + model->Backend()->ModelInstanceInitFn()(triton_instance); + + RETURN_IF_ERROR(slib->ResetLibraryDirectory()); + RETURN_IF_TRITONSERVER_ERROR(err); + } + + if (!passive) { + RETURN_IF_ERROR(local_instance->GenerateWarmupData()); + RETURN_IF_ERROR(model->Server()->GetRateLimiter()->RegisterModelInstance( + local_instance.get(), rate_limiter_config)); + RETURN_IF_ERROR(local_instance->SetBackendThread( + kind, device_id, device_blocking, device_to_thread_map)); + } + + RETURN_IF_ERROR(model->AddInstance(std::move(local_instance), passive)); + + return Status::Success; +} + +Status +TritonModelInstance::SetBackendThread( + const TRITONSERVER_InstanceGroupKind kind, const int32_t device_id, + const bool device_blocking, + std::map>* + device_to_thread_map) +{ + if (device_blocking && (kind == TRITONSERVER_INSTANCEGROUPKIND_GPU)) { + auto thread_it = device_to_thread_map->find(device_id); + if (thread_it != device_to_thread_map->end()) { + LOG_VERBOSE(1) << "Using already started backend thread for " << Name() + << " on device " << device_id; + triton_backend_thread_ = thread_it->second; + } + } + if (triton_backend_thread_.get() == nullptr) { + std::unique_ptr local_backend_thread; + RETURN_IF_ERROR(TritonBackendThread::CreateBackendThread( + Name(), this, 0 /* nice */, device_id, &local_backend_thread)); + triton_backend_thread_ = std::move(local_backend_thread); + device_to_thread_map->insert({device_id, triton_backend_thread_}); + } else { + triton_backend_thread_->AddModelInstance(this); + } + RETURN_IF_ERROR(triton_backend_thread_->InitAndWarmUpModelInstance(this)); + + return Status::Success; +} + +Status +TritonModelInstance::GenerateWarmupData() +{ + warmup_samples_.clear(); + for (const auto& warmup_setting : model_->Config().model_warmup()) { + if (warmup_setting.batch_size() == 0) { + LOG_VERBOSE(1) << "Skipping batch 0 warmup sample '" + << warmup_setting.name() << "'"; + continue; + } + LOG_VERBOSE(1) << "Generating warmup sample data for '" + << warmup_setting.name() << "'"; + + // Two passes. First pass to get max byte size for synthetic + // data. Second pass to add original inputs and override inputs + // for control inputs. 
+ int64_t max_zero_byte_size = 0; + int64_t max_random_byte_size = 0; + for (const auto& input_meta : warmup_setting.inputs()) { + auto element_count = + triton::common::GetElementCount(input_meta.second.dims()); + if (element_count == -1) { + return Status( + Status::Code::INVALID_ARG, + "warmup setting expects all variable-size dimensions are specified " + "for input '" + + input_meta.first + "'"); + } + + int64_t batch_byte_size = + element_count * + triton::common::GetDataTypeByteSize(input_meta.second.data_type()); + if (batch_byte_size == 0) { + batch_byte_size = element_count * sizeof(int32_t); + } + + switch (input_meta.second.input_data_type_case()) { + case inference::ModelWarmup_Input::InputDataTypeCase::kZeroData: + max_zero_byte_size = std::max(batch_byte_size, max_zero_byte_size); + break; + case inference::ModelWarmup_Input::InputDataTypeCase::kRandomData: { + // Because Triton expects STRING type to be in special format + // (prepend 4 bytes to specify string length), so using zero data + // for simplicity (4 bytes * element count of zeros). + if (input_meta.second.data_type() == + inference::DataType::TYPE_STRING) { + max_zero_byte_size = std::max(batch_byte_size, max_zero_byte_size); + } else { + max_random_byte_size = + std::max(batch_byte_size, max_random_byte_size); + } + break; + } + default: + break; + } + } + + warmup_samples_.emplace_back(warmup_setting.name(), warmup_setting.count()); + auto& warmup_data = warmup_samples_.back(); + // Create buffers for synthetic data + TRITONSERVER_MemoryType type; + int64_t type_id; + warmup_data.zero_data_.reset(new AllocatedMemory( + max_zero_byte_size, TRITONSERVER_MEMORY_CPU_PINNED /* memory_type */, + 0 /* memory_type_id */)); + char* zero_buffer = warmup_data.zero_data_->MutableBuffer(&type, &type_id); + memset(zero_buffer, 0, max_zero_byte_size); + + warmup_data.random_data_.reset(new AllocatedMemory( + max_random_byte_size, TRITONSERVER_MEMORY_CPU_PINNED /* memory_type */, + 0 /* memory_type_id */)); + char* random_buffer = + warmup_data.random_data_->MutableBuffer(&type, &type_id); + for (int64_t offset = 0; offset < max_random_byte_size; offset++) { + random_buffer[offset] = rand(); + } + + // Prepare the inference request for the specified sample, not using + // in-process C API because the request doesn't go through the same pipeline + // (i.e. no normalization / scheduler) so we need to prepare the request to + // the state just before calling instance execute function. + for (size_t cnt = 0; cnt < warmup_setting.batch_size(); cnt++) { + warmup_data.requests_.emplace_back( + new InferenceRequest(model_, model_->Version())); + auto& lrequest = warmup_data.requests_.back(); + + // Second pass to prepare original inputs. 
+ std::vector> input_sps; + for (const auto& input_meta : warmup_setting.inputs()) { + auto batch1_element_count = + triton::common::GetElementCount(input_meta.second.dims()); + auto batch_byte_size = + batch1_element_count * + triton::common::GetDataTypeByteSize(input_meta.second.data_type()); + if (batch_byte_size == 0) { + batch_byte_size = batch1_element_count * sizeof(int32_t); + } + + const char* allocated_ptr; + switch (input_meta.second.input_data_type_case()) { + case inference::ModelWarmup_Input::InputDataTypeCase::kZeroData: + allocated_ptr = zero_buffer; + break; + case inference::ModelWarmup_Input::InputDataTypeCase::kRandomData: { + if (input_meta.second.data_type() == + inference::DataType::TYPE_STRING) { + allocated_ptr = zero_buffer; + } else { + allocated_ptr = random_buffer; + } + break; + } + case inference::ModelWarmup_Input::InputDataTypeCase:: + kInputDataFile: { + // For data provided from file, we can set buffer in first pass + warmup_data.provided_data_.emplace_back(new std::string()); + auto input_data = warmup_data.provided_data_.back().get(); + RETURN_IF_ERROR(ReadTextFile( + JoinPath({model_->LocalizedModelPath(), kWarmupDataFolder, + input_meta.second.input_data_file()}), + input_data)); + if (input_meta.second.data_type() == + inference::DataType::TYPE_STRING) { + batch_byte_size = input_data->size(); + } else if (((size_t)batch_byte_size) > input_data->size()) { + return Status( + Status::Code::INVALID_ARG, + lrequest->LogRequest() + "warmup setting expects " + + std::to_string(batch_byte_size) + + " bytes, but the data " + "provided from " + + input_meta.second.input_data_file() + "only has " + + std::to_string(input_data->size()) + " bytes"); + } + allocated_ptr = input_data->data(); + break; + } + default: + return Status( + Status::Code::INVALID_ARG, + lrequest->LogRequest() + "warmup setting expects input '" + + input_meta.first + "' to have input_data_type set"); + } + + const inference::ModelInput* input_config; + bool is_original_input = + model_->GetInput(input_meta.first, &input_config).IsOk(); + InferenceRequest::Input* input = nullptr; + std::vector input_meta_shape; + // Append batch size only if the model supports batching + // and not control inpt. + if ((model_->Config().max_batch_size() != 0) && is_original_input) { + input_meta_shape.push_back(1); + } + for (auto d : input_meta.second.dims()) { + input_meta_shape.push_back(d); + } + if (is_original_input) { + RETURN_IF_ERROR(lrequest->AddOriginalInput( + input_meta.first, input_meta.second.data_type(), input_meta_shape, + &input)); + } else { + input_sps.emplace_back(); + RETURN_IF_ERROR(lrequest->AddOverrideInput( + input_meta.first, input_meta.second.data_type(), + (model_->Config().max_batch_size() != 0 ? 1 : 0), + input_meta_shape, &input_sps.back())); + input = input_sps.back().get(); + } + RETURN_IF_ERROR(input->AppendData( + allocated_ptr, batch_byte_size, + TRITONSERVER_MEMORY_CPU /* memory_type */, 0 /* memory_type_id */)); + } + + RETURN_IF_ERROR(lrequest->PrepareForInference()); + // Override inputs must be added after PrepareForInference() is called + for (const auto& sp : input_sps) { + RETURN_IF_ERROR(lrequest->AddOverrideInput(sp)); + } + } + } + + return Status::Success; +} + +void +TritonModelInstance::Schedule( + std::vector>&& requests, + const std::function& OnCompletion) +{ + // Use a thread local vector to avoid needing to malloc each + // time an inference is run. 
+ thread_local std::vector triton_requests(1024); + triton_requests.clear(); + for (auto& r : requests) { + // Load the input states for the inference request. + r->LoadInputStates(); + triton_requests.push_back( + reinterpret_cast(r.release())); + } + + Execute(triton_requests); + + OnCompletion(); +} + +Status +TritonModelInstance::Initialize() +{ + RETURN_IF_ERROR(SetNumaConfigOnThread(HostPolicy())); + return Status::Success; +} + +Status +TritonModelInstance::WarmUp() +{ + // move samples to local variable for scoped cleanup + std::vector lwarmup_samples; + lwarmup_samples.swap(warmup_samples_); + + for (auto& sample : lwarmup_samples) { + for (size_t iteration = 1; iteration <= sample.count_; ++iteration) { + LOG_VERBOSE(1) << "model '" << sample.requests_.back()->ModelName() + << "' instance " << Name() << " is running warmup sample '" + << sample.sample_name_ << "' for iteration " << iteration; + + // request/response complete is asynchronous so use promise to wait for + // completion. Also collects error message from the responses in a vector. + std::vector> request_complete(sample.requests_.size()); + std::vector response_errors; + std::vector, std::vector*>> + response_complete(sample.requests_.size()); + + std::vector triton_requests; + for (size_t i = 0; i < sample.requests_.size(); ++i) { + auto& request = sample.requests_[i]; + request->SetReleaseCallback( + WarmupRequestComplete, &request_complete[i]); + response_complete[i].second = &response_errors; + request->SetResponseCallback( + &warmup_allocator, nullptr, WarmupResponseComplete, + &response_complete[i]); + // Capture timestamp before run to avoid incorrect accumulation from + // sequential warmup runs +#ifdef TRITON_ENABLE_STATS + request->CaptureRequestStartNs(); +#endif // TRITON_ENABLE_STATS + request->CaptureQueueStartNs(); + triton_requests.push_back( + reinterpret_cast(request.get())); + } + + Execute(triton_requests); + + // Wait for warmup sample to complete and check error + for (size_t i = 0; i < sample.requests_.size(); ++i) { + request_complete[i].get_future().get(); + response_complete[i].first.get_future().get(); + } + if (response_errors.size() != 0) { + std::string err_str = + "failed to run warmup sample '" + sample.sample_name_ + "': "; + for (const auto& error : response_errors) { + err_str += (error + "; "); + } + // End warmup as soon as there is failing sample + LOG_VERBOSE(1) << "model '" << sample.requests_.back()->ModelName() + << "' instance " << Name() + << " failed to run warmup sample '" + << sample.sample_name_ << "'"; + return Status(Status::Code::INVALID_ARG, err_str); + } + } + } + + return Status::Success; +} + +void +TritonModelInstance::Execute( + std::vector& triton_requests) +{ + TRITONBACKEND_ModelInstance* triton_model_instance = + reinterpret_cast(this); + TritonBackend::TritonModelInstanceExecFn_t inst_exec_fn = + model_->Backend()->ModelInstanceExecFn(); + + // If there is an error then we retain ownership of 'requests' + // and must send error responses. 
+ TRITONSERVER_Error* err = inst_exec_fn( + triton_model_instance, &triton_requests[0], triton_requests.size()); + if (err != nullptr) { + Status status = Status( + TritonCodeToStatusCode(TRITONSERVER_ErrorCode(err)), + TRITONSERVER_ErrorMessage(err)); + for (TRITONBACKEND_Request* tr : triton_requests) { + std::unique_ptr ur( + reinterpret_cast(tr)); + InferenceRequest::RespondIfError(ur, status, true /* release_requests */); + } + + TRITONSERVER_ErrorDelete(err); + } +} + +Status +TritonModelInstance::TritonBackendThread::CreateBackendThread( + const std::string name, TritonModelInstance* model_instance, const int nice, + const int32_t device_id, + std::unique_ptr* triton_backend_thread) +{ + TritonBackendThread* raw_triton_backend_thread = + new TritonBackendThread(name, model_instance->Model()); + std::unique_ptr runner(raw_triton_backend_thread); + + runner->AddModelInstance(model_instance); + runner->backend_thread_ = + std::thread([raw_triton_backend_thread, nice, device_id]() { + raw_triton_backend_thread->BackendThread(nice, device_id); + }); + + triton_backend_thread->reset(runner.release()); + + return Status::Success; +} + +void +TritonModelInstance::TritonBackendThread::AddModelInstance( + TritonModelInstance* model_instance) +{ + model_instances_.push_back(model_instance); +} + +Status +TritonModelInstance::TritonBackendThread::InitAndWarmUpModelInstance( + TritonModelInstance* model_instance) +{ + // Initialize the instance on the backend thread + auto init_payload = model_->Server()->GetRateLimiter()->GetPayload( + Payload::Operation::INIT, model_instance); + RETURN_IF_ERROR( + model_->Server()->GetRateLimiter()->EnqueuePayload(model_, init_payload)); + RETURN_IF_ERROR(init_payload->Wait()); + + // Warm-up the instance on the backend thread + auto warmup_payload = model_->Server()->GetRateLimiter()->GetPayload( + Payload::Operation::WARM_UP, model_instance); + RETURN_IF_ERROR(model_->Server()->GetRateLimiter()->EnqueuePayload( + model_, warmup_payload)); + RETURN_IF_ERROR(warmup_payload->Wait()); + + return Status::Success; +} + +TritonModelInstance::TritonBackendThread::TritonBackendThread( + const std::string& name, TritonModel* model) + : name_(name), model_(model) +{ +} + +TritonModelInstance::TritonBackendThread::~TritonBackendThread() +{ + StopBackendThread(); +} + +void +TritonModelInstance::TritonBackendThread::StopBackendThread() +{ + if (backend_thread_.joinable()) { + // Signal the backend thread to exit and then wait for it... 
+ auto exit_payload = model_->Server()->GetRateLimiter()->GetPayload( + Payload::Operation::EXIT, model_instances_.back()); + model_->Server()->GetRateLimiter()->EnqueuePayload(model_, exit_payload); + backend_thread_.join(); + } +} + +void +TritonModelInstance::TritonBackendThread::BackendThread( + const int nice, const int32_t device_id) +{ +#ifndef _WIN32 + if (setpriority(PRIO_PROCESS, syscall(SYS_gettid), nice) == 0) { + LOG_VERBOSE(1) << "Starting backend thread for " << name_ << " at nice " + << nice << " on device " << device_id << "..."; + } else { + LOG_VERBOSE(1) << "Starting backend thread for " << name_ + << " at default nice (requested nice " << nice << " failed)" + << " on device " << device_id << "..."; + } +#else + LOG_VERBOSE(1) << "Starting backend thread for " << name_ + << " at default nice on device " << device_id << "..."; +#endif + + bool should_exit = false; + while (!should_exit) { + std::shared_ptr payload; + model_->Server()->GetRateLimiter()->DequeuePayload( + model_instances_, &payload); + NVTX_RANGE(nvtx_, "BackendThread " + name_); + payload->Execute(&should_exit); + model_instances_.push_back(payload->GetInstance()); + // Release the payload to the RateLimiter + model_->Server()->GetRateLimiter()->PayloadRelease(payload); + } + LOG_VERBOSE(1) << "Stopping backend thread for " << name_ << "..."; +} + +extern "C" { + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceName( + TRITONBACKEND_ModelInstance* instance, const char** name) +{ + TritonModelInstance* ti = reinterpret_cast(instance); + *name = ti->Name().c_str(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceKind( + TRITONBACKEND_ModelInstance* instance, TRITONSERVER_InstanceGroupKind* kind) +{ + TritonModelInstance* ti = reinterpret_cast(instance); + *kind = ti->Kind(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceDeviceId( + TRITONBACKEND_ModelInstance* instance, int32_t* device_id) +{ + TritonModelInstance* ti = reinterpret_cast(instance); + *device_id = ti->DeviceId(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceHostPolicy( + TRITONBACKEND_ModelInstance* instance, TRITONSERVER_Message** host_policy) +{ + TritonModelInstance* ti = reinterpret_cast(instance); + *host_policy = const_cast( + reinterpret_cast(&ti->HostPolicyMessage())); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceProfileCount( + TRITONBACKEND_ModelInstance* instance, uint32_t* count) +{ + TritonModelInstance* ti = reinterpret_cast(instance); + *count = ti->Profiles().size(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceProfileName( + TRITONBACKEND_ModelInstance* instance, const uint32_t index, + const char** profile_name) +{ + *profile_name = nullptr; + + TritonModelInstance* ti = reinterpret_cast(instance); + const auto& rprofiles = ti->Profiles(); + if (index >= rprofiles.size()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("out of bounds index ") + std::to_string(index) + + ": instance is configured with " + std::to_string(rprofiles.size()) + + " profiles") + .c_str()); + } + + *profile_name = rprofiles[index].c_str(); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceSecondaryDeviceCount( + TRITONBACKEND_ModelInstance* instance, uint32_t* count) 
+{ + TritonModelInstance* ti = reinterpret_cast(instance); + *count = ti->SecondaryDevices().size(); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceSecondaryDeviceProperties( + TRITONBACKEND_ModelInstance* instance, uint32_t index, const char** kind, + int64_t* id) +{ + TritonModelInstance* ti = reinterpret_cast(instance); + const auto& rsecondarydevices = ti->SecondaryDevices(); + + if (index >= rsecondarydevices.size()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("out of bounds index ") + std::to_string(index) + + ": instance is configured with " + + std::to_string(rsecondarydevices.size()) + " secondary devices") + .c_str()); + } + + *kind = rsecondarydevices[index].kind_.c_str(); + *id = rsecondarydevices[index].id_; + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceIsPassive( + TRITONBACKEND_ModelInstance* instance, bool* is_passive) +{ + TritonModelInstance* ti = reinterpret_cast(instance); + *is_passive = ti->IsPassive(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceModel( + TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Model** model) +{ + TritonModelInstance* ti = reinterpret_cast(instance); + *model = reinterpret_cast(ti->Model()); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceState( + TRITONBACKEND_ModelInstance* instance, void** state) +{ + TritonModelInstance* ti = reinterpret_cast(instance); + *state = ti->State(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceSetState( + TRITONBACKEND_ModelInstance* instance, void* state) +{ + TritonModelInstance* ti = reinterpret_cast(instance); + ti->SetState(state); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceReportStatistics( + TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request* request, + const bool success, const uint64_t exec_start_ns, + const uint64_t compute_start_ns, const uint64_t compute_end_ns, + const uint64_t exec_end_ns) +{ +#ifdef TRITON_ENABLE_STATS + TritonModelInstance* ti = reinterpret_cast(instance); + InferenceRequest* tr = reinterpret_cast(request); + tr->ReportStatistics( + ti->MetricReporter(), success, exec_start_ns, compute_start_ns, + compute_end_ns, exec_end_ns); +#endif // TRITON_ENABLE_STATS + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceReportBatchStatistics( + TRITONBACKEND_ModelInstance* instance, const uint64_t batch_size, + const uint64_t exec_start_ns, const uint64_t compute_start_ns, + const uint64_t compute_end_ns, const uint64_t exec_end_ns) +{ +#ifdef TRITON_ENABLE_STATS + TritonModelInstance* ti = reinterpret_cast(instance); + ti->Model()->MutableStatsAggregator()->UpdateInferBatchStats( + ti->MetricReporter(), batch_size, exec_start_ns, compute_start_ns, + compute_end_ns, exec_end_ns); +#endif // TRITON_ENABLE_STATS + + return nullptr; // success +} + +} // extern C +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/backend_model_instance.h b/3rdparty/core-r22.12/src/backend_model_instance.h new file mode 100644 index 0000000000000000000000000000000000000000..aa8ae94045c2b287c47c14cab30ec7013c01ebcf --- /dev/null +++ b/3rdparty/core-r22.12/src/backend_model_instance.h @@ -0,0 +1,200 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include +#include "constants.h" +#include "memory.h" +#include "metric_model_reporter.h" +#include "model_config.pb.h" +#include "server_message.h" +#include "status.h" +#include "triton/common/sync_queue.h" + +namespace triton { namespace core { + +class TritonModel; +class InferenceRequest; + +// +// Represents a model instance. 
+// +class TritonModelInstance { + public: + static Status CreateInstances( + TritonModel* model, + const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map, + const triton::common::HostPolicyCmdlineConfigMap& host_policy_map, + const inference::ModelConfig& model_config, const bool device_blocking); + ~TritonModelInstance(); + + const std::string& Name() const { return name_; } + size_t Index() const { return index_; } + TRITONSERVER_InstanceGroupKind Kind() const { return kind_; } + int32_t DeviceId() const { return device_id_; } + const triton::common::HostPolicyCmdlineConfig& HostPolicy() const + { + return host_policy_; + } + const TritonServerMessage& HostPolicyMessage() const + { + return host_policy_message_; + } + bool IsPassive() const { return passive_; } + const std::vector& Profiles() const { return profile_names_; } + + struct SecondaryDevice { + SecondaryDevice(const std::string kind, const int64_t id) + : kind_(kind), id_(id) + { + } + const std::string kind_; + const int64_t id_; + }; + const std::vector& SecondaryDevices() const + { + return secondary_devices_; + } + + Status Initialize(); + Status WarmUp(); + void Schedule( + std::vector>&& requests, + const std::function& OnCompletion); + + TritonModel* Model() const { return model_; } + void* State() { return state_; } + void SetState(void* state) { state_ = state; } + + MetricModelReporter* MetricReporter() const { return reporter_.get(); } + + private: + DISALLOW_COPY_AND_ASSIGN(TritonModelInstance); + class TritonBackendThread; + TritonModelInstance( + TritonModel* model, const std::string& name, const size_t index, + const TRITONSERVER_InstanceGroupKind kind, const int32_t device_id, + const std::vector& profile_names, const bool passive, + const triton::common::HostPolicyCmdlineConfig& host_policy, + const TritonServerMessage& host_policy_message, + const std::vector& secondary_devices); + static Status CreateInstance( + TritonModel* model, const std::string& name, const size_t index, + const TRITONSERVER_InstanceGroupKind kind, const int32_t device_id, + const std::vector& profile_names, const bool passive, + const std::string& host_policy_name, + const triton::common::HostPolicyCmdlineConfig& host_policy, + const inference::ModelRateLimiter& rate_limiter_config, + const bool device_blocking, + std::map>* + device_to_thread_map, + const std::vector& secondary_devices); + Status SetBackendThread( + const TRITONSERVER_InstanceGroupKind kind, const int32_t device_id, + const bool device_blocking, + std::map>* + device_to_thread_map); + Status GenerateWarmupData(); + + void Execute(std::vector& triton_requests); + + class TritonBackendThread { + public: + static Status CreateBackendThread( + const std::string name, TritonModelInstance* model, const int nice, + const int32_t device_id, + std::unique_ptr* triton_backend_thread); + void AddModelInstance(TritonModelInstance* model_instance); + Status InitAndWarmUpModelInstance(TritonModelInstance* model_instance); + void StopBackendThread(); + ~TritonBackendThread(); + + private: + TritonBackendThread(const std::string& name, TritonModel* model); + void BackendThread(const int nice, const int32_t device_id); + + std::string name_; + + TritonModel* model_; + std::deque model_instances_; + + std::thread backend_thread_; + std::atomic backend_thread_exit_; + }; + std::shared_ptr triton_backend_thread_; + + struct WarmupData { + WarmupData(const std::string& sample_name, const size_t count) + : sample_name_(sample_name), count_(std::max(count, size_t{1})) + 
{ + } + + std::string sample_name_; + size_t count_; + // Using a batch of requests to satisfy batch size, this provides better + // alignment on the batch expected by the model, especially for sequence + // model. + std::vector> requests_; + + // Placeholder for input data + std::unique_ptr zero_data_; + std::unique_ptr random_data_; + std::vector> provided_data_; + }; + std::vector warmup_samples_; + + // The TritonModel object that owns this instance. The instance + // holds this as a raw pointer because the lifetime of the model is + // guaranteed to be longer than the lifetime of an instance owned by the + // model. + TritonModel* model_; + + std::string name_; + size_t index_; + + // For CPU device_id_ is always 0. For GPU device_id_ indicates the + // GPU device to be used by the instance. + TRITONSERVER_InstanceGroupKind kind_; + int32_t device_id_; + const triton::common::HostPolicyCmdlineConfig host_policy_; + TritonServerMessage host_policy_message_; + std::vector profile_names_; + bool passive_; + + std::vector secondary_devices_; + + // Reporter for metrics, or nullptr if no metrics should be reported + std::shared_ptr reporter_; + + // Opaque state associated with this model instance. + void* state_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/buffer_attributes.cc b/3rdparty/core-r22.12/src/buffer_attributes.cc new file mode 100644 index 0000000000000000000000000000000000000000..d184662bd6cbd0fea672b548af30724f89382729 --- /dev/null +++ b/3rdparty/core-r22.12/src/buffer_attributes.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
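+
+// Illustrative sketch, not part of the upstream source: inside the core a
+// buffer is described by filling in a BufferAttributes object (implemented
+// below and declared in buffer_attributes.h) before it is exposed to backends
+// through the TRITONBACKEND_*BufferAttributes entry points seen earlier in
+// this patch. The variable names are placeholders.
+//
+//   BufferAttributes attrs;
+//   attrs.SetByteSize(tensor_byte_size);
+//   attrs.SetMemoryType(TRITONSERVER_MEMORY_GPU);
+//   attrs.SetMemoryTypeId(0 /* GPU 0 */);
+//   // For CUDA shared-memory regions, the 64-byte IPC handle can be attached:
+//   // attrs.SetCudaIpcHandle(&cuda_ipc_handle);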
+ +#include "buffer_attributes.h" + +#include +#include "constants.h" + +namespace triton { namespace core { +void +BufferAttributes::SetByteSize(const size_t& byte_size) +{ + byte_size_ = byte_size; +} + +void +BufferAttributes::SetMemoryType(const TRITONSERVER_MemoryType& memory_type) +{ + memory_type_ = memory_type; +} + +void +BufferAttributes::SetMemoryTypeId(const int64_t& memory_type_id) +{ + memory_type_id_ = memory_type_id; +} + +void +BufferAttributes::SetCudaIpcHandle(void* cuda_ipc_handle) +{ + char* lcuda_ipc_handle = reinterpret_cast(cuda_ipc_handle); + cuda_ipc_handle_.clear(); + std::copy( + lcuda_ipc_handle, lcuda_ipc_handle + CUDA_IPC_STRUCT_SIZE, + std::back_inserter(cuda_ipc_handle_)); +} + +void* +BufferAttributes::CudaIpcHandle() +{ + if (cuda_ipc_handle_.empty()) { + return nullptr; + } else { + return reinterpret_cast(cuda_ipc_handle_.data()); + } +} + +size_t +BufferAttributes::ByteSize() const +{ + return byte_size_; +} + +TRITONSERVER_MemoryType +BufferAttributes::MemoryType() const +{ + return memory_type_; +} + +int64_t +BufferAttributes::MemoryTypeId() const +{ + return memory_type_id_; +} + +BufferAttributes::BufferAttributes( + size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id, char* cuda_ipc_handle) + : byte_size_(byte_size), memory_type_(memory_type), + memory_type_id_(memory_type_id) +{ + // cuda ipc handle size + cuda_ipc_handle_.reserve(CUDA_IPC_STRUCT_SIZE); + + if (cuda_ipc_handle != nullptr) { + std::copy( + cuda_ipc_handle, cuda_ipc_handle + CUDA_IPC_STRUCT_SIZE, + std::back_inserter(cuda_ipc_handle_)); + } +} +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/buffer_attributes.h b/3rdparty/core-r22.12/src/buffer_attributes.h new file mode 100644 index 0000000000000000000000000000000000000000..aa89b3913403379946fe69cc8587638d355749c5 --- /dev/null +++ b/3rdparty/core-r22.12/src/buffer_attributes.h @@ -0,0 +1,79 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+#include <cstddef>
+#include <vector>
+#include "tritonserver_apis.h"
+
+#pragma once
+
+namespace triton { namespace core {
+//
+// A class to hold information about the buffer allocation.
+//
+class BufferAttributes {
+ public:
+  BufferAttributes(
+      size_t byte_size, TRITONSERVER_MemoryType memory_type,
+      int64_t memory_type_id, char cuda_ipc_handle[64]);
+  BufferAttributes()
+  {
+    memory_type_ = TRITONSERVER_MEMORY_CPU;
+    memory_type_id_ = 0;
+    cuda_ipc_handle_.reserve(64);
+  }
+
+  // Set the buffer byte size
+  void SetByteSize(const size_t& byte_size);
+
+  // Set the buffer memory_type
+  void SetMemoryType(const TRITONSERVER_MemoryType& memory_type);
+
+  // Set the buffer memory type id
+  void SetMemoryTypeId(const int64_t& memory_type_id);
+
+  // Set the cuda ipc handle
+  void SetCudaIpcHandle(void* cuda_ipc_handle);
+
+  // Get the cuda ipc handle
+  void* CudaIpcHandle();
+
+  // Get the byte size
+  size_t ByteSize() const;
+
+  // Get the memory type
+  TRITONSERVER_MemoryType MemoryType() const;
+
+  // Get the memory type id
+  int64_t MemoryTypeId() const;
+
+ private:
+  size_t byte_size_;
+  TRITONSERVER_MemoryType memory_type_;
+  int64_t memory_type_id_;
+  std::vector<char> cuda_ipc_handle_;
+};
+}} // namespace triton::core
diff --git a/3rdparty/core-r22.12/src/constants.h b/3rdparty/core-r22.12/src/constants.h
new file mode 100644
index 0000000000000000000000000000000000000000..40d0705586774357e04a9a77983862ab0c3a44f8
--- /dev/null
+++ b/3rdparty/core-r22.12/src/constants.h
@@ -0,0 +1,108 @@
+// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//   * Redistributions of source code must retain the above copyright
+//     notice, this list of conditions and the following disclaimer.
+//   * Redistributions in binary form must reproduce the above copyright
+//     notice, this list of conditions and the following disclaimer in the
+//     documentation and/or other materials provided with the distribution.
+//   * Neither the name of NVIDIA CORPORATION nor the names of its
+//     contributors may be used to endorse or promote products derived
+//     from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
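For orientation, a minimal usage sketch of the BufferAttributes class declared above; DescribeCpuBuffer and the 1 KB size are illustrative only, and the snippet assumes the buffer_attributes.h header added by this diff (which provides the TRITONSERVER_MemoryType values via tritonserver_apis.h).

```cpp
#include "buffer_attributes.h"

// Hypothetical helper: describe a 1 KB buffer that lives in CPU memory.
triton::core::BufferAttributes
DescribeCpuBuffer()
{
  triton::core::BufferAttributes attrs;  // defaults: CPU memory, type id 0
  attrs.SetByteSize(1024);
  attrs.SetMemoryType(TRITONSERVER_MEMORY_CPU);
  attrs.SetMemoryTypeId(0);
  // For a GPU buffer shared across processes, the 64-byte CUDA IPC handle
  // would be attached with attrs.SetCudaIpcHandle(&handle) instead.
  return attrs;
}
```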
+#pragma once + +#include + +namespace triton { namespace core { + +constexpr char kInferHeaderContentLengthHTTPHeader[] = + "Inference-Header-Content-Length"; +constexpr char kAcceptEncodingHTTPHeader[] = "Accept-Encoding"; +constexpr char kContentEncodingHTTPHeader[] = "Content-Encoding"; +constexpr char kContentTypeHeader[] = "Content-Type"; +constexpr char kContentLengthHeader[] = "Content-Length"; + +constexpr char kTensorFlowGraphDefPlatform[] = "tensorflow_graphdef"; +constexpr char kTensorFlowSavedModelPlatform[] = "tensorflow_savedmodel"; +constexpr char kTensorFlowGraphDefFilename[] = "model.graphdef"; +constexpr char kTensorFlowSavedModelFilename[] = "model.savedmodel"; +constexpr char kTensorFlowBackend[] = "tensorflow"; + +constexpr char kTensorRTPlanPlatform[] = "tensorrt_plan"; +constexpr char kTensorRTPlanFilename[] = "model.plan"; +constexpr char kTensorRTBackend[] = "tensorrt"; + +constexpr char kOnnxRuntimeOnnxPlatform[] = "onnxruntime_onnx"; +constexpr char kOnnxRuntimeOnnxFilename[] = "model.onnx"; +constexpr char kOnnxRuntimeBackend[] = "onnxruntime"; + +constexpr char kOpenVINORuntimeOpenVINOFilename[] = "model.xml"; +constexpr char kOpenVINORuntimeBackend[] = "openvino"; + +constexpr char kPyTorchLibTorchPlatform[] = "pytorch_libtorch"; +constexpr char kPyTorchLibTorchFilename[] = "model.pt"; +constexpr char kPyTorchBackend[] = "pytorch"; + +constexpr char kPythonFilename[] = "model.py"; +constexpr char kPythonBackend[] = "python"; + +#ifdef TRITON_ENABLE_ENSEMBLE +constexpr char kEnsemblePlatform[] = "ensemble"; +#endif // TRITON_ENABLE_ENSEMBLE + +constexpr char kTensorRTExecutionAccelerator[] = "tensorrt"; +constexpr char kOpenVINOExecutionAccelerator[] = "openvino"; +constexpr char kGPUIOExecutionAccelerator[] = "gpu_io"; +constexpr char kAutoMixedPrecisionExecutionAccelerator[] = + "auto_mixed_precision"; + +constexpr char kModelConfigPbTxt[] = "config.pbtxt"; + +constexpr char kMetricsLabelModelName[] = "model"; +constexpr char kMetricsLabelModelVersion[] = "version"; +constexpr char kMetricsLabelGpuUuid[] = "gpu_uuid"; + +constexpr char kWarmupDataFolder[] = "warmup"; +constexpr char kInitialStateFolder[] = "initial_state"; + +constexpr uint64_t NANOS_PER_SECOND = 1000000000; +constexpr uint64_t NANOS_PER_MILLIS = 1000000; +constexpr int MAX_GRPC_MESSAGE_SIZE = INT32_MAX; +constexpr uint64_t SEQUENCE_IDLE_DEFAULT_MICROSECONDS = 1000 * 1000; +constexpr size_t STRING_CORRELATION_ID_MAX_LENGTH_BYTES = 128; +constexpr size_t CUDA_IPC_STRUCT_SIZE = 64; + +#ifdef TRITON_ENABLE_METRICS +// MetricModelReporter expects a device ID for GPUs, but we reuse this device +// ID for other metrics as well such as for CPU and Response Cache metrics +constexpr int METRIC_REPORTER_ID_CPU = -1; +constexpr int METRIC_REPORTER_ID_RESPONSE_CACHE = -2; +#endif + +#define TIMESPEC_TO_NANOS(TS) \ + ((TS).tv_sec * triton::core::NANOS_PER_SECOND + (TS).tv_nsec) +#define TIMESPEC_TO_MILLIS(TS) \ + (TIMESPEC_TO_NANOS(TS) / triton::core::NANOS_PER_MILLIS) + +#define DISALLOW_MOVE(TypeName) TypeName(Context&& o) = delete; +#define DISALLOW_COPY(TypeName) TypeName(const TypeName&) = delete; +#define DISALLOW_ASSIGN(TypeName) void operator=(const TypeName&) = delete; +#define DISALLOW_COPY_AND_ASSIGN(TypeName) \ + DISALLOW_COPY(TypeName) \ + DISALLOW_ASSIGN(TypeName) + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/cuda_memory_manager.cc b/3rdparty/core-r22.12/src/cuda_memory_manager.cc new file mode 100644 index 
0000000000000000000000000000000000000000..eec9206c1f8f8f814b7a5aa16860d2cf6addb3ec --- /dev/null +++ b/3rdparty/core-r22.12/src/cuda_memory_manager.cc @@ -0,0 +1,197 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +#include "cuda_memory_manager.h" + +#include +#include +#include +#include "cuda_utils.h" +#include "triton/common/logging.h" + +namespace { + +#define RETURN_IF_CNMEM_ERROR(S, MSG) \ + do { \ + auto status__ = (S); \ + if (status__ != CNMEM_STATUS_SUCCESS) { \ + return Status( \ + Status::Code::INTERNAL, \ + (MSG) + ": " + cnmemGetErrorString(status__)); \ + } \ + } while (false) + +std::string +PointerToString(void* ptr) +{ + std::stringstream ss; + ss << ptr; + return ss.str(); +} + +} // namespace + +namespace triton { namespace core { + +std::unique_ptr CudaMemoryManager::instance_; +std::mutex CudaMemoryManager::instance_mu_; + +CudaMemoryManager::~CudaMemoryManager() +{ + if (has_allocation_) { + auto status = cnmemFinalize(); + if (status != CNMEM_STATUS_SUCCESS) { + LOG_ERROR << "Failed to finalize CUDA memory manager: [" << status << "] " + << cnmemGetErrorString(status); + } + } +} + +void +CudaMemoryManager::Reset() +{ + std::lock_guard lock(instance_mu_); + instance_.reset(); +} + +Status +CudaMemoryManager::Create(const CudaMemoryManager::Options& options) +{ + // Ensure thread-safe creation of CUDA memory pool + std::lock_guard lock(instance_mu_); + if (instance_ != nullptr) { + LOG_WARNING << "New CUDA memory pools could not be created since they " + "already exists"; + return Status::Success; + } + + std::set supported_gpus; + auto status = GetSupportedGPUs( + &supported_gpus, options.min_supported_compute_capability_); + if (status.IsOk()) { + std::vector devices; + for (auto gpu : supported_gpus) { + const auto it = options.memory_pool_byte_size_.find(gpu); + if ((it != options.memory_pool_byte_size_.end()) && (it->second != 0)) { + devices.emplace_back(); + auto& device = devices.back(); + memset(&device, 0, sizeof(device)); + device.device = gpu; + 
device.size = it->second; + + LOG_INFO << "CUDA memory pool is created on device " << device.device + << " with size " << device.size; + } + } + + if (!devices.empty()) { + RETURN_IF_CNMEM_ERROR( + cnmemInit(devices.size(), devices.data(), CNMEM_FLAGS_CANNOT_GROW), + std::string("Failed to finalize CUDA memory manager")); + } else { + LOG_INFO << "CUDA memory pool disabled"; + } + + // Use to finalize CNMeM properly when out of scope + instance_.reset(new CudaMemoryManager(!devices.empty())); + } else { + return Status( + status.ErrorCode(), + "Failed to initialize CUDA memory manager: " + status.Message()); + } + + return Status::Success; +} + +Status +CudaMemoryManager::Alloc(void** ptr, uint64_t size, int64_t device_id) +{ + if (instance_ == nullptr) { + return Status( + Status::Code::UNAVAILABLE, "CudaMemoryManager has not been created"); + } else if (!instance_->has_allocation_) { + return Status( + Status::Code::UNAVAILABLE, + "CudaMemoryManager has no preallocated CUDA memory"); + } + + int current_device; + RETURN_IF_CUDA_ERR( + cudaGetDevice(¤t_device), std::string("Failed to get device")); + bool overridden = (current_device != device_id); + if (overridden) { + RETURN_IF_CUDA_ERR( + cudaSetDevice(device_id), std::string("Failed to set device")); + } + + // Defer returning error to make sure the device is recovered + auto err = cnmemMalloc(ptr, size, nullptr); + + if (overridden) { + cudaSetDevice(current_device); + } + + RETURN_IF_CNMEM_ERROR( + err, std::string("Failed to allocate CUDA memory with byte size ") + + std::to_string(size) + " on GPU " + std::to_string(device_id)); + return Status::Success; +} + +Status +CudaMemoryManager::Free(void* ptr, int64_t device_id) +{ + if (instance_ == nullptr) { + return Status( + Status::Code::UNAVAILABLE, "CudaMemoryManager has not been created"); + } else if (!instance_->has_allocation_) { + return Status( + Status::Code::UNAVAILABLE, + "CudaMemoryManager has no preallocated CUDA memory"); + } + + int current_device; + RETURN_IF_CUDA_ERR( + cudaGetDevice(¤t_device), std::string("Failed to get device")); + bool overridden = (current_device != device_id); + if (overridden) { + RETURN_IF_CUDA_ERR( + cudaSetDevice(device_id), std::string("Failed to set device")); + } + + // Defer returning error to make sure the device is recovered + auto err = cnmemFree(ptr, nullptr); + + if (overridden) { + cudaSetDevice(current_device); + } + + RETURN_IF_CNMEM_ERROR( + err, std::string("Failed to deallocate CUDA memory at address ") + + PointerToString(ptr) + " on GPU " + std::to_string(device_id)); + return Status::Success; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/cuda_memory_manager.h b/3rdparty/core-r22.12/src/cuda_memory_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..cc06d8ca1d3b57ea5cff8c735b27da8e54c248fb --- /dev/null +++ b/3rdparty/core-r22.12/src/cuda_memory_manager.h @@ -0,0 +1,85 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. 
+// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +#pragma once + +#include +#include +#include +#include "status.h" + +namespace triton { namespace core { + +// This is a singleton class responsible for maintaining CUDA memory pool +// used by the inference server. CUDA memory allocations and deallocations +// must be requested via functions provided by this class. +class CudaMemoryManager { + public: + // Options to configure CUDA memory manager. + struct Options { + Options(double cc = 6.0, const std::map& s = {}) + : min_supported_compute_capability_(cc), memory_pool_byte_size_(s) + { + } + + // The minimum compute capability of the supported devices. + double min_supported_compute_capability_; + + // The size of CUDA memory reserved for the specified devices. + // The memory size will be rounded up to align with + // the default granularity (512 bytes). + // No memory will be reserved for devices that is not listed. + std::map memory_pool_byte_size_; + }; + + ~CudaMemoryManager(); + + // Create the memory manager based on 'options' specified. + // Return Status object indicating success or failure. + static Status Create(const Options& options); + + // Allocate CUDA memory on GPU 'device_id' with + // the requested 'size' and return the pointer in 'ptr'. + // Return Status object indicating success or failure. + static Status Alloc(void** ptr, uint64_t size, int64_t device_id); + + // Free the memory allocated by the memory manager on 'device_id'. + // Return Status object indicating success or failure. + static Status Free(void* ptr, int64_t device_id); + + protected: + // Provide explicit control on the lifecycle of the CUDA memory manager, + // for testing only. + static void Reset(); + + private: + CudaMemoryManager(bool has_allocation) : has_allocation_(has_allocation) {} + bool has_allocation_; + static std::unique_ptr instance_; + static std::mutex instance_mu_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/cuda_utils.cc b/3rdparty/core-r22.12/src/cuda_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..e758c3a8a71bd4ebf0d6d69b5176995a1f81e69c --- /dev/null +++ b/3rdparty/core-r22.12/src/cuda_utils.cc @@ -0,0 +1,263 @@ +// Copyright 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
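To make the intended call pattern of the CudaMemoryManager singleton concrete, here is a hedged sketch. UsePooledGpuMemory, the 64 MB pool size, and device 0 are illustrative; it assumes a TRITON_ENABLE_GPU build, the cuda_memory_manager.h and status.h headers added by this diff, and that the Options byte-size map is keyed by device id as used in Create() above.

```cpp
#include <map>
#include "cuda_memory_manager.h"

// Hypothetical call pattern: reserve a 64 MB pool on GPU 0 at startup, then
// allocate and release a 1 KB chunk from that pool.
triton::core::Status
UsePooledGpuMemory()
{
  using triton::core::CudaMemoryManager;
  using triton::core::Status;

  // Devices not listed in the map get no reservation.
  CudaMemoryManager::Options options(
      6.0 /* min compute capability */, {{0, 64 * 1024 * 1024}});
  Status status = CudaMemoryManager::Create(options);
  if (!status.IsOk()) {
    return status;
  }

  void* buffer = nullptr;
  status = CudaMemoryManager::Alloc(&buffer, 1024 /* bytes */, 0 /* gpu */);
  if (!status.IsOk()) {
    return status;
  }

  // ... hand 'buffer' to code running on GPU 0 ...

  return CudaMemoryManager::Free(buffer, 0 /* gpu */);
}
```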
+// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "cuda_utils.h" + +#include "model_config_utils.h" +#include "triton/common/nvtx.h" + +namespace triton { namespace core { + +#ifdef TRITON_ENABLE_GPU +void CUDART_CB +MemcpyHost(void* args) +{ + auto* copy_params = reinterpret_cast(args); + memcpy(copy_params->dst_, copy_params->src_, copy_params->byte_size_); + delete copy_params; +} +#endif // TRITON_ENABLE_GPU + +Status +GetDeviceMemoryInfo(const int device_id, size_t* free, size_t* total) +{ + *free = 0; + *total = 0; +#ifdef TRITON_ENABLE_GPU + // Make sure that correct device is set before creating stream and + // then restore the device to what was set by the caller. + int current_device; + auto cuerr = cudaGetDevice(¤t_device); + bool overridden = false; + if (cuerr == cudaSuccess) { + overridden = (current_device != device_id); + if (overridden) { + cuerr = cudaSetDevice(device_id); + } + } + + if (cuerr == cudaSuccess) { + cuerr = cudaMemGetInfo(free, total); + } + + if (overridden) { + cudaSetDevice(current_device); + } + + if (cuerr != cudaSuccess) { + return Status( + Status::Code::INTERNAL, + (std::string("unable to get memory info for device ") + + std::to_string(device_id) + ": " + cudaGetErrorString(cuerr))); + } +#endif // TRITON_ENABLE_GPU + return Status::Success; +} + +Status +EnablePeerAccess(const double min_compute_capability) +{ +#ifdef TRITON_ENABLE_GPU + // If we can't enable peer access for one device pair, the best we can + // do is skipping it... 
+ std::set supported_gpus; + bool all_enabled = false; + if (GetSupportedGPUs(&supported_gpus, min_compute_capability).IsOk()) { + all_enabled = true; + int can_access_peer = false; + for (const auto& host : supported_gpus) { + auto cuerr = cudaSetDevice(host); + + if (cuerr == cudaSuccess) { + for (const auto& peer : supported_gpus) { + if (host == peer) { + continue; + } + + cuerr = cudaDeviceCanAccessPeer(&can_access_peer, host, peer); + if ((cuerr == cudaSuccess) && (can_access_peer == 1)) { + cuerr = cudaDeviceEnablePeerAccess(peer, 0); + } + + all_enabled &= ((cuerr == cudaSuccess) && (can_access_peer == 1)); + } + } + } + } + if (!all_enabled) { + return Status( + Status::Code::UNSUPPORTED, + "failed to enable peer access for some device pairs"); + } +#endif // TRITON_ENABLE_GPU + return Status::Success; +} + +Status +CopyBuffer( + const std::string& msg, const TRITONSERVER_MemoryType src_memory_type, + const int64_t src_memory_type_id, + const TRITONSERVER_MemoryType dst_memory_type, + const int64_t dst_memory_type_id, const size_t byte_size, const void* src, + void* dst, cudaStream_t cuda_stream, bool* cuda_used, bool copy_on_stream) +{ + NVTX_RANGE(nvtx_, "CopyBuffer"); + + *cuda_used = false; + + // For CUDA memcpy, all host to host copy will be blocked in respect to the + // host, so use memcpy() directly. In this case, need to be careful on whether + // the src buffer is valid. + if ((src_memory_type != TRITONSERVER_MEMORY_GPU) && + (dst_memory_type != TRITONSERVER_MEMORY_GPU)) { +#ifdef TRITON_ENABLE_GPU + if (copy_on_stream) { + auto params = new CopyParams(dst, src, byte_size); + cudaLaunchHostFunc( + cuda_stream, MemcpyHost, reinterpret_cast(params)); + *cuda_used = true; + } else { + memcpy(dst, src, byte_size); + } +#else + memcpy(dst, src, byte_size); +#endif // TRITON_ENABLE_GPU + } else { +#ifdef TRITON_ENABLE_GPU + RETURN_IF_CUDA_ERR( + cudaMemcpyAsync(dst, src, byte_size, cudaMemcpyDefault, cuda_stream), + msg + ": failed to perform CUDA copy"); + + *cuda_used = true; +#else + return Status( + Status::Code::INTERNAL, + msg + ": try to use CUDA copy while GPU is not supported"); +#endif // TRITON_ENABLE_GPU + } + + return Status::Success; +} + +void +CopyBufferHandler( + const std::string& msg, const TRITONSERVER_MemoryType src_memory_type, + const int64_t src_memory_type_id, + const TRITONSERVER_MemoryType dst_memory_type, + const int64_t dst_memory_type_id, const size_t byte_size, const void* src, + void* dst, cudaStream_t cuda_stream, void* response_ptr, + triton::common::SyncQueue>* + completion_queue) +{ + bool cuda_used = false; + Status status = CopyBuffer( + msg, src_memory_type, src_memory_type_id, dst_memory_type, + dst_memory_type_id, byte_size, src, dst, cuda_stream, &cuda_used); + completion_queue->Put(std::make_tuple(status, cuda_used, response_ptr)); +} + +#ifdef TRITON_ENABLE_GPU +Status +CheckGPUCompatibility(const int gpu_id, const double min_compute_capability) +{ + // Query the compute capability from the device + cudaDeviceProp cuprops; + cudaError_t cuerr = cudaGetDeviceProperties(&cuprops, gpu_id); + if (cuerr != cudaSuccess) { + return Status( + Status::Code::INTERNAL, + "unable to get CUDA device properties for GPU ID" + + std::to_string(gpu_id) + ": " + cudaGetErrorString(cuerr)); + } + + double compute_compability = cuprops.major + (cuprops.minor / 10.0); + if ((compute_compability > min_compute_capability) || + (abs(compute_compability - min_compute_capability) < 0.01)) { + return Status::Success; + } else { + return Status( + 
Status::Code::UNSUPPORTED, + "gpu " + std::to_string(gpu_id) + " has compute capability '" + + std::to_string(cuprops.major) + "." + + std::to_string(cuprops.minor) + + "' which is less than the minimum supported of '" + + std::to_string(min_compute_capability) + "'"); + } +} + +Status +GetSupportedGPUs( + std::set* supported_gpus, const double min_compute_capability) +{ + // Make sure set is empty before starting + supported_gpus->clear(); + + int device_cnt; + cudaError_t cuerr = cudaGetDeviceCount(&device_cnt); + if ((cuerr == cudaErrorNoDevice) || (cuerr == cudaErrorInsufficientDriver)) { + device_cnt = 0; + } else if (cuerr != cudaSuccess) { + return Status( + Status::Code::INTERNAL, "unable to get number of CUDA devices: " + + std::string(cudaGetErrorString(cuerr))); + } + + // populates supported_gpus + for (int gpu_id = 0; gpu_id < device_cnt; gpu_id++) { + Status status = CheckGPUCompatibility(gpu_id, min_compute_capability); + if (status.IsOk()) { + supported_gpus->insert(gpu_id); + } + } + return Status::Success; +} + +Status +SupportsIntegratedZeroCopy(const int gpu_id, bool* zero_copy_support) +{ + // Query the device to check if integrated + cudaDeviceProp cuprops; + cudaError_t cuerr = cudaGetDeviceProperties(&cuprops, gpu_id); + if (cuerr != cudaSuccess) { + return Status( + Status::Code::INTERNAL, + "unable to get CUDA device properties for GPU ID" + + std::to_string(gpu_id) + ": " + cudaGetErrorString(cuerr)); + } + + // Zero-copy supported only on integrated GPU when it can map host memory + if (cuprops.integrated && cuprops.canMapHostMemory) { + *zero_copy_support = true; + } else { + *zero_copy_support = false; + } + + return Status::Success; +} + +#endif + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/cuda_utils.h b/3rdparty/core-r22.12/src/cuda_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..abe900be3d5720e9acb328501e03bd45fded5187 --- /dev/null +++ b/3rdparty/core-r22.12/src/cuda_utils.h @@ -0,0 +1,144 @@ +// Copyright 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
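The compute-capability check in CheckGPUCompatibility above reduces to "major + minor/10 must be at least the configured minimum, within a small tolerance". A standalone illustration of just that comparison follows; MeetsMinimum and the sample values are made up for the example.

```cpp
#include <cmath>
#include <cstdio>

// Mirrors the comparison used by CheckGPUCompatibility: accept the device if
// its capability exceeds the minimum or matches it within 0.01.
bool
MeetsMinimum(int major, int minor, double min_cc)
{
  const double cc = major + (minor / 10.0);
  return (cc > min_cc) || (std::abs(cc - min_cc) < 0.01);
}

int main()
{
  std::printf("7.5 vs 6.0 -> %d\n", MeetsMinimum(7, 5, 6.0));  // 1: supported
  std::printf("5.2 vs 6.0 -> %d\n", MeetsMinimum(5, 2, 6.0));  // 0: too old
  std::printf("6.0 vs 6.0 -> %d\n", MeetsMinimum(6, 0, 6.0));  // 1: equal within tolerance
  return 0;
}
```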
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include "status.h" +#include "triton/common/sync_queue.h" + +#ifdef TRITON_ENABLE_GPU +#include +#endif // TRITON_ENABLE_GPU + +namespace triton { namespace core { + +#ifdef TRITON_ENABLE_GPU +#define RETURN_IF_CUDA_ERR(X, MSG) \ + do { \ + cudaError_t err__ = (X); \ + if (err__ != cudaSuccess) { \ + return Status( \ + Status::Code::INTERNAL, (MSG) + ": " + cudaGetErrorString(err__)); \ + } \ + } while (false) +#endif // TRITON_ENABLE_GPU + +#ifndef TRITON_ENABLE_GPU +using cudaStream_t = void*; +#endif // !TRITON_ENABLE_GPU + +/// Get the memory info for the specified device. +/// \param device_id The device ID. +/// \param free Return free memory in bytes. +/// \param total Return total memory in bytes. +/// \return The error status. A non-OK status means failure to get memory info. +Status GetDeviceMemoryInfo(const int device_id, size_t* free, size_t* total); + +/// Enable peer access for all GPU device pairs +/// \param min_compute_capability The minimum support CUDA compute +/// capability. +/// \return The error status. A non-OK status means not all pairs are enabled +Status EnablePeerAccess(const double min_compute_capability); + +/// Copy buffer from 'src' to 'dst' for given 'byte_size'. The buffer location +/// is identified by the memory type and id, and the corresponding copy will be +/// initiated. +/// \param msg The message to be prepended in error message. +/// \param src_memory_type The memory type CPU/GPU of the source. +/// \param src_memory_type_id The device id of the source. +/// \param dst_memory_type The memory type CPU/GPU of the destination. +/// \param dst_memory_type_id The device id of the destination. +/// \param byte_size The size in bytes to me copied from source to destination. +/// \param src The buffer start address of the source. +/// \param dst The buffer start address of the destination. +/// \param cuda_stream The stream to be associated with, and 0 can be +/// passed for default stream. +/// \param cuda_used returns whether a CUDA memory copy is initiated. If true, +/// the caller should synchronize on the given 'cuda_stream' to ensure data copy +/// is completed. +/// \param copy_on_stream whether the memory copies should be performed in cuda +/// host functions on the 'cuda_stream'. +/// \return The error status. A non-ok status indicates failure to copy the +/// buffer. +Status CopyBuffer( + const std::string& msg, const TRITONSERVER_MemoryType src_memory_type, + const int64_t src_memory_type_id, + const TRITONSERVER_MemoryType dst_memory_type, + const int64_t dst_memory_type_id, const size_t byte_size, const void* src, + void* dst, cudaStream_t cuda_stream, bool* cuda_used, + bool copy_on_stream = false); + +#ifdef TRITON_ENABLE_GPU +/// Validates the compute capability of the GPU indexed +/// \param gpu_id The index of the target GPU. +/// \param min_compute_capability The minimum support CUDA compute +/// capability. +/// \return The error status. 
A non-OK status means the target GPU is +/// not supported. +Status CheckGPUCompatibility( + const int gpu_id, const double min_compute_capability); + +/// Obtains a set of gpu ids that is supported by triton. +/// \param supported_gpus Returns the set of integers which is +/// populated by ids of supported GPUS +/// \param min_compute_capability The minimum support CUDA compute +/// capability. +/// \return The error status. A non-ok status means there were +/// errors encountered while querying GPU devices. +Status GetSupportedGPUs( + std::set* supported_gpus, const double min_compute_capability); + +/// Checks if the GPU specified is an integrated GPU and supports Zero-copy. +/// \param gpu_id The index of the target GPU. +/// \param zero_copy_support If true, Zero-copy is supported by this GPU. +/// \return The error status. A non-OK status means the target GPU is +/// not supported. +Status SupportsIntegratedZeroCopy(const int gpu_id, bool* zero_copy_support); +#endif + +// Helper around CopyBuffer that updates the completion queue with the returned +// status and cuda_used flag. +void CopyBufferHandler( + const std::string& msg, const TRITONSERVER_MemoryType src_memory_type, + const int64_t src_memory_type_id, + const TRITONSERVER_MemoryType dst_memory_type, + const int64_t dst_memory_type_id, const size_t byte_size, const void* src, + void* dst, cudaStream_t cuda_stream, void* response_ptr, + triton::common::SyncQueue>* + completion_queue); + +struct CopyParams { + CopyParams(void* dst, const void* src, const size_t byte_size) + : dst_(dst), src_(src), byte_size_(byte_size) + { + } + + void* dst_; + const void* src_; + const size_t byte_size_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/dynamic_batch_scheduler.cc b/3rdparty/core-r22.12/src/dynamic_batch_scheduler.cc new file mode 100644 index 0000000000000000000000000000000000000000..c608aa3709a2ece50eb8d02e9a52faf60b6109b9 --- /dev/null +++ b/3rdparty/core-r22.12/src/dynamic_batch_scheduler.cc @@ -0,0 +1,698 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
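To show how a caller is expected to use CopyBuffer in the simple CPU-to-CPU case (where no stream synchronization is needed), here is a hedged sketch. CopyHostTensor is a hypothetical helper; only the CopyBuffer signature declared above and the cuda_utils.h / status.h headers added by this diff are assumed.

```cpp
#include <vector>
#include "cuda_utils.h"

// Hypothetical helper: copy a host-resident tensor into a destination vector.
triton::core::Status
CopyHostTensor(const std::vector<char>& src, std::vector<char>* dst)
{
  dst->resize(src.size());
  bool cuda_used = false;
  // CPU -> CPU falls back to memcpy (or a host function enqueued on the
  // stream if copy_on_stream were set), so cuda_used stays false here and no
  // stream synchronization is required afterwards.
  return triton::core::CopyBuffer(
      "example tensor" /* msg */, TRITONSERVER_MEMORY_CPU, 0 /* src id */,
      TRITONSERVER_MEMORY_CPU, 0 /* dst id */, src.size(), src.data(),
      dst->data(), nullptr /* default stream */, &cuda_used);
}
```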
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "dynamic_batch_scheduler.h" + +#ifndef _WIN32 +#include +#include +#include +#endif +#include "constants.h" +#include "server.h" +#include "triton/common/logging.h" +#include "triton/common/model_config.h" +#include "triton/common/nvtx.h" + +namespace triton { namespace core { + +bool +IsStaleState(Payload::State payload_state) +{ + return ( + (payload_state == Payload::State::EXECUTING) || + (payload_state == Payload::State::RELEASED)); +} + +DynamicBatchScheduler::DynamicBatchScheduler( + TritonModel* model, TritonModelInstance* model_instance, + const bool dynamic_batching_enabled, const int32_t max_batch_size, + const std::unordered_map& enforce_equal_shape_tensors, + const bool preserve_ordering, const bool response_cache_enable, + const std::set& preferred_batch_sizes, + const uint64_t max_queue_delay_microseconds, + const inference::ModelQueuePolicy& default_queue_policy, + const uint32_t priority_levels, const ModelQueuePolicyMap& queue_policy_map) + : model_(model), model_instance_(model_instance), + model_name_(model->Name()), + dynamic_batching_enabled_(dynamic_batching_enabled), + queue_(default_queue_policy, priority_levels, queue_policy_map), + stop_(false), max_batch_size_((size_t)std::max(1, max_batch_size)), + preferred_batch_sizes_(preferred_batch_sizes), + pending_batch_delay_ns_(max_queue_delay_microseconds * 1000), + pending_batch_size_(0), queued_batch_size_(0), + next_preferred_batch_size_(0), + enforce_equal_shape_tensors_(enforce_equal_shape_tensors), + has_optional_input_(false), preserve_ordering_(preserve_ordering) +{ + rate_limiter_ = model_->Server()->GetRateLimiter(); + // Both the server and model config should specify + // caching enabled for model to utilize response cache. 
+ response_cache_enabled_ = + (model_->Server()->ResponseCacheEnabled() && response_cache_enable); +#ifdef TRITON_ENABLE_METRICS + // Initialize metric reporter for cache statistics if cache enabled + if (response_cache_enabled_) { + MetricModelReporter::Create( + model_name_, model_->Version(), METRIC_REPORTER_ID_RESPONSE_CACHE, + model_->Config().metric_tags(), &reporter_); + } +#endif // TRITON_ENABLE_METRICS + max_preferred_batch_size_ = 0; + for (const auto size : preferred_batch_sizes_) { + max_preferred_batch_size_ = + std::max(max_preferred_batch_size_, (size_t)size); + } + + for (const auto& input : model_->Config().input()) { + if (input.optional()) { + has_optional_input_ = true; + break; + } + } +} + +Status +DynamicBatchScheduler::Create( + TritonModel* model, TritonModelInstance* model_instance, const int nice, + const bool dynamic_batching_enabled, const int32_t max_batch_size, + const std::unordered_map& enforce_equal_shape_tensors, + const bool preserve_ordering, const bool response_cache_enable, + const std::set& preferred_batch_sizes, + const uint64_t max_queue_delay_microseconds, + std::unique_ptr* scheduler) +{ + inference::ModelDynamicBatching batcher_config; + batcher_config.set_preserve_ordering(preserve_ordering); + for (const auto& bs : preferred_batch_sizes) { + batcher_config.add_preferred_batch_size(bs); + } + batcher_config.set_max_queue_delay_microseconds(max_queue_delay_microseconds); + + return Create( + model, model_instance, nice, dynamic_batching_enabled, max_batch_size, + enforce_equal_shape_tensors, batcher_config, response_cache_enable, + scheduler); +} + +Status +DynamicBatchScheduler::Create( + TritonModel* model, TritonModelInstance* model_instance, const int nice, + const bool dynamic_batching_enabled, const int32_t max_batch_size, + const std::unordered_map& enforce_equal_shape_tensors, + const inference::ModelDynamicBatching& batcher_config, + const bool response_cache_enable, std::unique_ptr* scheduler) +{ + std::set preferred_batch_sizes; + for (const auto size : batcher_config.preferred_batch_size()) { + preferred_batch_sizes.insert(size); + } + + DynamicBatchScheduler* dyna_sched = new DynamicBatchScheduler( + model, model_instance, dynamic_batching_enabled, max_batch_size, + enforce_equal_shape_tensors, batcher_config.preserve_ordering(), + response_cache_enable, preferred_batch_sizes, + batcher_config.max_queue_delay_microseconds(), + batcher_config.default_queue_policy(), batcher_config.priority_levels(), + batcher_config.priority_queue_policy()); + std::unique_ptr sched(dyna_sched); + + sched->scheduler_thread_exit_.store(false); + if (dynamic_batching_enabled) { + sched->NewPayload(); + sched->scheduler_thread_ = + std::thread([dyna_sched, nice]() { dyna_sched->BatcherThread(nice); }); + } + + scheduler->reset(sched.release()); + + return Status::Success; +} + +DynamicBatchScheduler::~DynamicBatchScheduler() +{ + // Signal the scheduler thread to exit and then wait for it.. + scheduler_thread_exit_.store(true); + cv_.notify_one(); + if (scheduler_thread_.joinable()) { + scheduler_thread_.join(); + } +} + +Status +DynamicBatchScheduler::Enqueue(std::unique_ptr& request) +{ + if (stop_) { + return Status( + Status::Code::UNAVAILABLE, + request->LogRequest() + + "Server is stopping, scheduler for model has stopped accepting new " + "inference requests"); + } + // If queue start timestamp hasn't been set, queue timer starts at + // the beginning of the queueing and scheduling process. 
Otherwise, + // dynamic batcher is used as component of another batcher and should not + // overwrite the queue start timestamp. + if (request->QueueStartNs() == 0) { + request->CaptureQueueStartNs(); + INFER_TRACE_ACTIVITY( + request->Trace(), TRITONSERVER_TRACE_QUEUE_START, + request->QueueStartNs()); +#ifdef TRITON_ENABLE_TRACING + request->TraceInputTensors( + TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT, "DynamicBatchScheduler Enqueue"); +#endif // TRITON_ENABLE_TRACING + } + + // Record time at the beginning of the batcher queueing. In the case of + // oldest sequence batcher, this will overwrite the value that was previously + // set by sequence batcher, which is okay as by this point, the previous + // batcher won't be needing this value and it can be safely reused by + // the dynamic batcher. + request->CaptureBatcherStartNs(); + + std::unique_ptr cached_response; + + if (response_cache_enabled_) { + CacheLookUp(request, cached_response); + } + + if (cached_response != nullptr) { + // If there was a cache hit then try sending the cached response + // and release the request. + if (preserve_ordering_) { + // In order to preserve the order, the response send must be + // delegated. + DelegateResponse(request); + } + + // Send cached response and release request + InferenceResponse::Send( + std::move(cached_response), TRITONSERVER_RESPONSE_COMPLETE_FINAL); + InferenceRequest::Release( + std::move(request), TRITONSERVER_REQUEST_RELEASE_ALL); + + return Status::Success; + } + + if (!dynamic_batching_enabled_) { + if (preserve_ordering_ || response_cache_enabled_) { + DelegateResponse(request); + } + // If not using dynamic batching, directly enqueue the + // request to model for execution + auto payload = model_->Server()->GetRateLimiter()->GetPayload( + Payload::Operation::INFER_RUN, nullptr /* TritonModelInstance*/); + payload->AddRequest(std::move(request)); + RETURN_IF_ERROR( + model_->Server()->GetRateLimiter()->EnqueuePayload(model_, payload)); + + } else { + bool wake_batcher = true; + { + std::lock_guard lock(mu_); + + queued_batch_size_ += std::max(1U, request->BatchSize()); + + // Assuming no error is returned, this call takes ownership of + // 'request' and so we can't use it after this point. + RETURN_IF_ERROR(queue_.Enqueue(request->Priority(), request)); + + // If there are any idle runners and the queued batch size is greater or + // equal to next preferred batch size, then wake batcher up to service + // this request. 
We do the actual wake outside of the lock to avoid + // having the woken thread immediately block on the lock + wake_batcher = + model_->Server()->GetRateLimiter()->PayloadSlotAvailable(model_); + + // We may wake up runner less often if we don't enforce equal shape + // within a batch, otherwise must always wake up runner to check it + if (enforce_equal_shape_tensors_.empty()) { + std::lock_guard exec_lock(*(curr_payload_->GetExecMutex())); + auto payload_state = curr_payload_->GetState(); + wake_batcher &= + (payload_saturated_ || IsStaleState(payload_state) || + (queued_batch_size_ >= next_preferred_batch_size_)); + } + } + + if (wake_batcher) { + cv_.notify_one(); + } + } + + return Status::Success; +} + +void +DynamicBatchScheduler::NewPayload() +{ + curr_payload_ = model_->Server()->GetRateLimiter()->GetPayload( + Payload::Operation::INFER_RUN, model_instance_); + payload_saturated_ = false; +} + +void +DynamicBatchScheduler::BatcherThread(const int nice) +{ +#ifndef _WIN32 + if (setpriority(PRIO_PROCESS, syscall(SYS_gettid), nice) == 0) { + LOG_VERBOSE(1) << "Starting dynamic-batcher thread for " << model_name_ + << " at nice " << nice << "..."; + } else { + LOG_VERBOSE(1) << "Starting dynamic-batcher thread for " << model_name_ + << " at default nice (requested nice " << nice + << " failed)..."; + } +#else + LOG_VERBOSE(1) << "Starting dynamic-batcher thread for " << model_name_ + << " at default nice..."; +#endif + // For debugging/testing, delay start of threads until the queue + // contains the specified number of entries. + size_t delay_cnt = 0; + { + const char* dstr = getenv("TRITONSERVER_DELAY_SCHEDULER"); + if (dstr != nullptr) { + delay_cnt = atoi(dstr); + LOG_VERBOSE(1) << "Delaying batcher thread for " << model_name_ + << " until " << delay_cnt << " queued requests..."; + } + } + + auto wait_for_slots = [this]() { + return model_->Server()->GetRateLimiter()->PayloadSlotAvailable(model_); + }; + const uint64_t default_wait_microseconds = 500 * 1000; + + while (!scheduler_thread_exit_.load()) { + NVTX_RANGE(nvtx_, "DynamicBatcher " + model_name_); + + std::shared_ptr>>> + rejected_requests; + uint64_t wait_microseconds = 0; + + // Hold the lock for as short a time as possible. + { + std::unique_lock lock(mu_); + { + std::lock_guard exec_lock(*(curr_payload_->GetExecMutex())); + auto payload_state = curr_payload_->GetState(); + if (payload_saturated_ || IsStaleState(payload_state)) { + NewPayload(); + next_preferred_batch_size_ = 0; + } + } + + if (delay_cnt > 0) { + // Debugging/testing... wait until queue contains 'delay_cnt' + // items... + wait_microseconds = 10 * 1000; + if (queue_.Size() >= delay_cnt) { + delay_cnt = 0; + } + LOG_VERBOSE(1) << "Delaying batcher thread " << model_name_ << " until " + << delay_cnt + << " queued requests, current total = " << queue_.Size(); + } else if (queue_.Empty()) { + wait_microseconds = default_wait_microseconds; + } else { + if (payload_saturated_) { + continue; + } + cv_.wait(lock, wait_for_slots); + { + std::lock_guard exec_lock( + *(curr_payload_->GetExecMutex())); + + auto payload_state = curr_payload_->GetState(); + if (IsStaleState(payload_state)) { + continue; + } + + // Use dynamic batching to get request(s) to execute. + wait_microseconds = GetDynamicBatch(); + + // Get requests that are rejected from searching dynamic batch. 
+ queue_.ReleaseRejectedRequests(&rejected_requests); + + // Extract batch only if there is pending batch + auto pending_batch_queue_cnt = queue_.PendingBatchCount(); + if ((wait_microseconds == 0) && (pending_batch_queue_cnt != 0)) { + curr_payload_->ReserveRequests(pending_batch_queue_cnt); + for (size_t idx = 0; idx < pending_batch_queue_cnt; ++idx) { + std::unique_ptr request; + auto status = queue_.Dequeue(&request); + if (status.IsOk()) { + if (preserve_ordering_ || response_cache_enabled_) { + DelegateResponse(request); + } + curr_payload_->AddRequest(std::move(request)); + } else { + // The queue is empty which conflicts with pending batch + // count. Send the current batch if any and reset related + // variables. + LOG_ERROR << request->LogRequest() + << "Failed to retrieve request from scheduler queue: " + << status.Message(); + queue_.ResetCursor(); + queued_batch_size_ = 0; + pending_batch_size_ = 0; + break; + } + } + + if (curr_payload_->GetState() == Payload::State::UNINITIALIZED) { + curr_payload_->SetState(Payload::State::READY); + } + + queued_batch_size_ -= pending_batch_size_; + pending_batch_size_ = 0; + } + } + } + + // If no requests are to be handled, wait for notification or + // for the specified timeout before checking the queue again. + if (wait_microseconds > 0) { + std::chrono::microseconds wait_timeout(wait_microseconds); + cv_.wait_for(lock, wait_timeout); + } + } + + if (curr_payload_->GetState() == Payload::State::READY) { + auto callback = [this]() { cv_.notify_one(); }; + curr_payload_->SetCallback(callback); + model_->Server()->GetRateLimiter()->EnqueuePayload(model_, curr_payload_); + } + + // Finish rejected requests if any + if (rejected_requests != nullptr) { + static Status rejected_status = + Status(Status::Code::UNAVAILABLE, "Request timeout expired"); + for (auto& rejected_queue : *rejected_requests) { + for (auto& rejected_request : rejected_queue) { + InferenceRequest::RespondIfError( + rejected_request, rejected_status, true); + } + } + } + } // end runner loop + + LOG_VERBOSE(1) << "Stopping dynamic-batcher thread for " << model_name_ + << "..."; +} + +uint64_t +DynamicBatchScheduler::GetDynamicBatch() +{ + // 'mu_' mutex must be held when this function is called. queue_ + // must not be empty. + + // Examine the new requests. If adding these new requests to the + // pending batch allows a preferred batch size then execute it + // immediately. Stop examining requests if the maximum preferred + // batch size would be exceeded or if the shape of the next request + // does not match the shape of the pending batch. + bool send_now = false; + if (!queue_.IsCursorValid()) { + queue_.ResetCursor(); + pending_batch_size_ = 0; + } + size_t best_preferred_batch_size = 0; + queued_batch_size_ -= queue_.ApplyPolicyAtCursor(); + + // When there is optional input or input shape must be enforced, + // the inputs in the requests must be examined for forming a batch + const bool check_input = + !enforce_equal_shape_tensors_.empty() || has_optional_input_; + auto payload_batch_size = curr_payload_->BatchSize(); + while (!queue_.CursorEnd()) { + const auto batch_size = std::max(1U, queue_.RequestAtCursor()->BatchSize()); + + // If there is no pending batch, then this request is starting a + // new batch. + if ((payload_batch_size + queue_.PendingBatchCount()) == 0) { + // Get the shape of the new batch that is being started... 
+ if (check_input) { + if (!curr_payload_->MutableRequiredEqualInputs() + ->Initialize( + queue_.RequestAtCursor(), enforce_equal_shape_tensors_, + has_optional_input_) + .IsOk()) { + send_now = true; + break; + } + } + } else { + // There is a pending batch and adding this request would make + // the batch size larger than all of the preferred batch sizes, + // so mark the cursor at this point. Not sending the pending batch so + // that we can examine the queue delay of requests that fits in a batch. + if (((payload_batch_size + pending_batch_size_ + batch_size) > + max_preferred_batch_size_) && + (best_preferred_batch_size == 0)) { + best_preferred_batch_size = pending_batch_size_; + queue_.MarkCursor(); + payload_saturated_ = true; + } + if ((payload_batch_size + pending_batch_size_ + batch_size) > + max_batch_size_) { + send_now = true; + break; + } + + // There is a pending batch and it has a different shape then + // this request, so send the pending batch as it is. + if (check_input && + !curr_payload_->MutableRequiredEqualInputs()->HasEqualInputs( + queue_.RequestAtCursor())) { + curr_payload_->MarkSaturated(); + send_now = true; + break; + } + } + + pending_batch_size_ += batch_size; + queue_.AdvanceCursor(); + queued_batch_size_ -= queue_.ApplyPolicyAtCursor(); + + if (preferred_batch_sizes_.find(pending_batch_size_ + payload_batch_size) != + preferred_batch_sizes_.end()) { + best_preferred_batch_size = pending_batch_size_; + queue_.MarkCursor(); + } + } + + // Obatin the age of the oldest pending request to compare with the maximum + // batch queuing delay + uint64_t now_ns = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint64_t delay_ns = now_ns - queue_.OldestEnqueueTime(); + bool delay_is_exceeded = + (pending_batch_delay_ns_ != 0) && (delay_ns >= pending_batch_delay_ns_); + + // If we found a preferred batch size and the queue delay hasn't been + // exceeded, then execute that. + if ((best_preferred_batch_size != 0) && !delay_is_exceeded) { + if (pending_batch_delay_ns_ == 0) { + payload_saturated_ = true; + } + pending_batch_size_ = best_preferred_batch_size; + queue_.SetCursorToMark(); + return 0; + } + + // No request in pending batch happens when all queued requests have expired + // timeout and the policies are REJECT + if (queue_.PendingBatchCount() == 0) { + return 0; + } + + // If the delay has been exceeded, or if the current batch can't grow + // any larger then just immediately execute whatever is pending. + if (send_now || ((payload_batch_size + pending_batch_size_) >= + max_preferred_batch_size_)) { + payload_saturated_ = true; + return 0; + } + + if (delay_is_exceeded || (pending_batch_delay_ns_ == 0)) { + return 0; + } + + // Set the next preferred batch size given the pending batch size + auto next_preferred_batch_size_it = preferred_batch_sizes_.upper_bound( + pending_batch_size_ + payload_batch_size); + if (next_preferred_batch_size_it != preferred_batch_sizes_.end()) { + next_preferred_batch_size_ = *next_preferred_batch_size_it; + } else { + next_preferred_batch_size_ = + preferred_batch_sizes_.empty() ? 0 : *preferred_batch_sizes_.begin(); + } + if (next_preferred_batch_size_ != 0) { + next_preferred_batch_size_ -= payload_batch_size; + } + + // By this point, we have not seen the pending batch that should be executed + // immediately. 
However, if we have scheduled a payload that can be grown and + // not yet in preferred batch size, we should move the pending batch over to + // ensure the model instance will pick up largest available batch even if it + // is not the preferred batch. + if (!payload_saturated_ && (payload_batch_size != 0) && + (preferred_batch_sizes_.find(payload_batch_size) == + preferred_batch_sizes_.end())) { + return 0; + } + + uint64_t wait_ns = pending_batch_delay_ns_ - delay_ns; + // Note that taking request timeout into consideration allows us to reset + // pending batch as soon as it is invalidated. But the cost is that in edge + // case where the timeout will be expired one by one, the thread will be + // waken frequently. + if (queue_.ClosestTimeout() != 0) { + if (now_ns <= queue_.ClosestTimeout()) { + wait_ns = std::min(queue_.ClosestTimeout() - now_ns, wait_ns); + } else { + // A request in pending batch is timed-out, wait for 1 us to force the + // thread to reset the pending batch right the way. + wait_ns = 1000; + } + } + + // Return non-zero wait microseconds to cause this thread to wait + // until the queue delay or the closest timeout has expired. + // Another thread may be awaken due to incoming request to handle the + // pending batch before this thread wakes and that is ok. But if no other + // request comes in then this thread will wake and revisit the pending batch + // (and at that time will then see the delay has been exceeded and will send + // the batch). + return wait_ns / 1000; +} + +void +DynamicBatchScheduler::DelegateResponse( + std::unique_ptr& request) +{ + std::lock_guard lock(completion_queue_mtx_); + completion_queue_.emplace_back(); + auto queue_slot = &completion_queue_.back(); + // Pass raw ptr to lambda for tracking stats from cache and updating + // metric reporter on cache miss stats after insertion + InferenceRequest* raw_request_ptr = request.get(); + + request->SetResponseDelegator( + [this, queue_slot, raw_request_ptr]( + std::unique_ptr&& response, const uint32_t flags) { + if (response_cache_enabled_ && raw_request_ptr->CacheKeyIsSet()) { + // Cache insertion happens here because we need the backend to have + // computed the inference response first in the case of cache miss + auto cache = model_->Server()->GetResponseCache(); + auto status = cache->Insert(*response, raw_request_ptr); + bool cache_miss = + (status.StatusCode() != Status::Code::ALREADY_EXISTS); + if (cache_miss) { +#ifdef TRITON_ENABLE_STATS + // Update cache miss statistics even on failure to insert + // as we still spend time on lookup and attempting to insert + raw_request_ptr->ReportStatisticsCacheMiss(reporter_.get()); +#endif // TRITON_ENABLE_STATS + + if (!status.IsOk()) { + LOG_ERROR << raw_request_ptr->LogRequest() + << "Failed to insert request_hash [" + << raw_request_ptr->CacheKey() + << "] into response cache: " << status.Message(); + } + } // Otherwise do nothing; we update cache hit statistics on Lookup + } + + if (preserve_ordering_) { + { + std::lock_guard lock(completion_queue_mtx_); + queue_slot->emplace_back(std::move(response), flags); + } + FinalizeResponses(); + } else { + InferenceResponse::Send(std::move(response), flags); + } + }); +} + +void +DynamicBatchScheduler::CacheLookUp( + std::unique_ptr& request, + std::unique_ptr& cached_response) +{ + auto cache = model_->Server()->GetResponseCache(); + // Lookup request in cache + std::unique_ptr local_response; + request->ResponseFactory()->CreateResponse(&local_response); + auto status = 
cache->Lookup(local_response.get(), request.get()); + if (status.IsOk() && (local_response != nullptr)) { + cached_response = std::move(local_response); +#ifdef TRITON_ENABLE_STATS + // Update model metrics/stats on cache hits + // Backends will update metrics as normal on cache misses + request->ReportStatisticsCacheHit(reporter_.get()); +#endif // TRITON_ENABLE_STATS + } +} + +void +DynamicBatchScheduler::FinalizeResponses() +{ + // Need exclusive access of the function to ensure responses are sent + // in order + std::lock_guard lock(finalize_mtx_); + // Finalize the completed payloads in-order as far as possible + std::deque, const uint32_t>> + responses; + { + std::lock_guard queue_lock(completion_queue_mtx_); + while (!completion_queue_.empty() && !completion_queue_.front().empty()) { + bool response_complete = false; + for (auto& response_pair : completion_queue_.front()) { + // Assuming FINAL flag is set only in the last response of the request + response_complete = + ((response_pair.second & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != + 0); + responses.emplace_back(std::move(response_pair)); + } + if (response_complete) { + completion_queue_.pop_front(); + } else { + completion_queue_.front().clear(); + } + } + } + + for (auto& response : responses) { + InferenceResponse::Send(std::move(response.first), response.second); + } +} +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/dynamic_batch_scheduler.h b/3rdparty/core-r22.12/src/dynamic_batch_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..16818a9dcbcff78ae4ea6b5d720313b395079807 --- /dev/null +++ b/3rdparty/core-r22.12/src/dynamic_batch_scheduler.h @@ -0,0 +1,182 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
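+//
+// dynamic_batch_scheduler.h declares the DynamicBatchScheduler defined in
+// dynamic_batch_scheduler.cc above: queued requests are coalesced into a
+// pending batch up to the preferred / maximum batch sizes and dispatched
+// either immediately or after the configured maximum queue delay.
+//
+// Construction sketch, for orientation only; the actual call site lives in
+// the model setup code and its arguments may differ:
+//
+//   std::unique_ptr<Scheduler> scheduler;
+//   RETURN_IF_ERROR(DynamicBatchScheduler::Create(
+//       model, nullptr /* model_instance */, 0 /* nice */,
+//       true /* dynamic_batching_enabled */, config.max_batch_size(),
+//       enforce_equal_shape_tensors, config.dynamic_batching(),
+//       config.response_cache().enable(), &scheduler));
+//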
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend_model.h" +#include "backend_model_instance.h" +#include "model_config.pb.h" +#include "rate_limiter.h" +#include "scheduler.h" +#include "scheduler_utils.h" +#include "status.h" +#include "triton/common/model_config.h" + +namespace triton { namespace core { + +// Scheduler that implements dynamic batching. +class DynamicBatchScheduler : public Scheduler { + public: + // Create a scheduler to support a given number of runners and a run + // function to call when a request is scheduled. + static Status Create( + TritonModel* model, TritonModelInstance* model_instance, const int nice, + const bool dynamic_batching_enabled, const int32_t max_batch_size, + const std::unordered_map& enforce_equal_shape_tensors, + const bool preserve_ordering, const bool response_cache_enable, + const std::set& preferred_batch_sizes, + const uint64_t max_queue_delay_microseconds, + std::unique_ptr* scheduler); + + // Create a scheduler to support a given number of runners and a run + // function to call when a request is scheduled. And the scheduler also + // supports different queue policies for different priority levels. + static Status Create( + TritonModel* model, TritonModelInstance* model_instance, const int nice, + const bool dynamic_batching_enabled, const int32_t max_batch_size, + const std::unordered_map& enforce_equal_shape_tensors, + const inference::ModelDynamicBatching& batcher_config, + const bool response_cache_enable, std::unique_ptr* scheduler); + + ~DynamicBatchScheduler(); + + // \see Scheduler::Enqueue() + Status Enqueue(std::unique_ptr& request) override; + + // \see Scheduler::InflightInferenceCount() + size_t InflightInferenceCount() override + { + std::unique_lock lock(mu_); + if (curr_payload_ != nullptr) { + return queue_.Size() + curr_payload_->RequestCount(); + } + return queue_.Size(); + } + + // \see Scheduler::Stop() + void Stop() override { stop_ = true; } + + MetricModelReporter* MetricReporter() const { return reporter_.get(); } + + private: + DynamicBatchScheduler( + TritonModel* model, TritonModelInstance* model_instance, + const bool dynamic_batching_enabled, const int32_t max_batch_size, + const std::unordered_map& enforce_equal_shape_tensors, + const bool preserve_ordering, const bool response_cache_enable, + const std::set& preferred_batch_sizes, + const uint64_t max_queue_delay_microseconds, + const inference::ModelQueuePolicy& default_queue_policy, + const uint32_t priority_levels, + const ModelQueuePolicyMap& queue_policy_map); + + void BatcherThread(const int nice); + void NewPayload(); + uint64_t GetDynamicBatch(); + void DelegateResponse(std::unique_ptr& request); + void CacheLookUp( + std::unique_ptr& request, + std::unique_ptr& cached_response); + void FinalizeResponses(); + + TritonModel* model_; + TritonModelInstance* model_instance_; + + // Name of the model. + std::string model_name_; + + // True if dynamic batching is enabled. + const bool dynamic_batching_enabled_; + + // Map from priority level to queue holding inference requests for the model + // represented by this scheduler. If priority queues are not supported by the + // scheduler, then priority zero entry is used as the single queue. 
+ PriorityQueue queue_; + bool stop_; + + std::thread scheduler_thread_; + std::atomic scheduler_thread_exit_; + + // Mutex and condvar for signaling scheduler thread + std::mutex mu_; + std::condition_variable cv_; + + std::shared_ptr rate_limiter_; + + std::shared_ptr curr_payload_; + bool payload_saturated_; + + size_t max_batch_size_; + size_t max_preferred_batch_size_; + std::set preferred_batch_sizes_; + uint64_t pending_batch_delay_ns_; + size_t pending_batch_size_; + + size_t queued_batch_size_; + size_t next_preferred_batch_size_; + + // The input tensors that require shape checking before being + // allowed in a batch. As a map from the tensor name to a bool. If + // tensor is in map then its shape must match shape of same tensor + // in requests already in the batch. If value is "true" then + // additional tensor is treated as a shape tensor and the values + // contained in the shape tensor must match same tensor already in + // the batch. + const std::unordered_map enforce_equal_shape_tensors_; + + // Store information on whether the model contains optional inputs. + bool has_optional_input_; + + // If true the ordering of responses matches the order of requests + // even when there are multiple scheduler threads. + const bool preserve_ordering_; + + // If true, the scheduler will try to retrieve responses from cache. + bool response_cache_enabled_; + + // Per completion-id queues to store the ready responses + std::deque< + std::vector, uint32_t>>> + completion_queue_; + // Lock to protect the completion_queues_ + std::mutex completion_queue_mtx_; + + // Preserves the order in which responses are finalized + std::mutex finalize_mtx_; + + // Reporter for metrics, or nullptr if no metrics should be reported + std::shared_ptr reporter_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/ensemble_model.cc b/3rdparty/core-r22.12/src/ensemble_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..b263a6512a78c2148b21f2648f1eb4c95f117113 --- /dev/null +++ b/3rdparty/core-r22.12/src/ensemble_model.cc @@ -0,0 +1,67 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "ensemble_model.h" + +#include +#include "constants.h" +#include "ensemble_scheduler.h" +#include "model_config_utils.h" +#include "triton/common/logging.h" + +namespace triton { namespace core { + +Status +EnsembleModel::Create( + InferenceServer* server, const std::string& path, const int64_t version, + const inference::ModelConfig& model_config, const bool is_config_provided, + const double min_compute_capability, std::unique_ptr* model) +{ + // Create the ensemble model. + std::unique_ptr local_model( + new EnsembleModel(min_compute_capability, path, version, model_config)); + + RETURN_IF_ERROR(local_model->Init(is_config_provided)); + + std::unique_ptr scheduler; + RETURN_IF_ERROR(EnsembleScheduler::Create( + local_model->MutableStatsAggregator(), server, model_config, &scheduler)); + RETURN_IF_ERROR(local_model->SetScheduler(std::move(scheduler))); + + LOG_VERBOSE(1) << "ensemble model for " << local_model->Name() << std::endl; + + *model = std::move(local_model); + return Status::Success; +} + +std::ostream& +operator<<(std::ostream& out, const EnsembleModel& pb) +{ + out << "name=" << pb.Name() << std::endl; + return out; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/ensemble_model.h b/3rdparty/core-r22.12/src/ensemble_model.h new file mode 100644 index 0000000000000000000000000000000000000000..b24df739aad45f7e91cb904f6449ce9258d8c5e3 --- /dev/null +++ b/3rdparty/core-r22.12/src/ensemble_model.h @@ -0,0 +1,60 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include "model.h" +#include "model_config.pb.h" +#include "scheduler.h" +#include "status.h" + +namespace triton { namespace core { + +class InferenceServer; + +class EnsembleModel : public Model { + public: + EnsembleModel(EnsembleModel&&) = default; + + static Status Create( + InferenceServer* server, const std::string& path, const int64_t version, + const inference::ModelConfig& model_config, const bool is_config_provided, + const double min_compute_capability, std::unique_ptr* model); + + private: + DISALLOW_COPY_AND_ASSIGN(EnsembleModel); + + explicit EnsembleModel( + const double min_compute_capability, const std::string& model_dir, + const int64_t version, const inference::ModelConfig& config) + : Model(min_compute_capability, model_dir, version, config) + { + } + friend std::ostream& operator<<(std::ostream&, const EnsembleModel&); +}; + +std::ostream& operator<<(std::ostream& out, const EnsembleModel& pb); + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/ensemble_scheduler.cc b/3rdparty/core-r22.12/src/ensemble_scheduler.cc new file mode 100644 index 0000000000000000000000000000000000000000..76d520c3580d1ceb112f3b5def8c73286ec21e8a --- /dev/null +++ b/3rdparty/core-r22.12/src/ensemble_scheduler.cc @@ -0,0 +1,1390 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
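+//
+// ensemble_scheduler.cc implements the scheduler used for ensemble models:
+// each ensemble request is wrapped in a RequestTracker, an EnsembleContext
+// walks the ensemble step graph, and every internal response feeds its
+// output tensors forward until the requested ensemble outputs can be
+// assembled and returned (see the EnsembleContext comments further below).
+//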
+ +#ifdef TRITON_ENABLE_ENSEMBLE + +#include "ensemble_scheduler.h" + +#include +#include "cuda_utils.h" +#include "metrics.h" +#include "model.h" +#include "model_config_utils.h" +#include "server.h" +#include "triton/common/logging.h" + +namespace triton { namespace core { + +namespace { + +class EnsembleContext; + +using IterationCount = size_t; + +// Request tracker is passed as 'userp' in RequestRelease function and used +// to manage the lifecycle of the ensemble request +class RequestTracker { + public: + explicit RequestTracker( + std::unique_ptr&& request, uint64_t compute_start_ns, + MetricModelReporter* metric_reporter, + InferenceStatsAggregator* stats_aggregator) + : inflight_request_counter_(1), request_(std::move(request)), + compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter), + stats_aggregator_(stats_aggregator), status_(Status::Success) + { + } + + std::unique_ptr& Request() { return request_; } + + InferenceStatsAggregator& ContextStatsAggregator() + { + return context_stats_aggregator_; + } + + void IncrementCounter() + { + std::lock_guard lk(mtx_); + inflight_request_counter_++; + } + + bool DecrementCounter() + { + std::lock_guard lk(mtx_); + inflight_request_counter_--; + if (inflight_request_counter_ == 0) { +#ifdef TRITON_ENABLE_STATS + const auto& infer_stats = context_stats_aggregator_.ImmutableInferStats(); + request_->ReportStatisticsWithDuration( + metric_reporter_, status_.IsOk(), compute_start_ns_, + infer_stats.compute_input_duration_ns_, + infer_stats.compute_infer_duration_ns_, + infer_stats.compute_output_duration_ns_); + if (status_.IsOk()) { + stats_aggregator_->UpdateInferBatchStatsWithDuration( + metric_reporter_, std::max(1U, request_->BatchSize()), + infer_stats.compute_input_duration_ns_, + infer_stats.compute_infer_duration_ns_, + infer_stats.compute_output_duration_ns_); + } +#endif + InferenceRequest::Release( + std::move(request_), TRITONSERVER_REQUEST_RELEASE_ALL); + } + return (inflight_request_counter_ == 0); + } + + void SetStatus(const Status& status) + { + std::lock_guard lk(mtx_); + status_ = status; + } + + private: + std::mutex mtx_; + uint32_t inflight_request_counter_; + std::unique_ptr request_; + uint64_t compute_start_ns_; + MetricModelReporter* metric_reporter_; + InferenceStatsAggregator* stats_aggregator_; + InferenceStatsAggregator context_stats_aggregator_; + Status status_; +}; + +// Step is used as 'userp' and keeps ensemble context alive +// until no more internal requests are inflight. 
+// Step contains metadata, and status for the +// internal infer request +struct Step { + Step( + size_t step_idx, const InferenceRequest::SequenceId& correlation_id, + uint32_t flags) + : correlation_id_(correlation_id), flags_(flags), response_flags_(0), + infer_status_(nullptr), step_idx_(step_idx) + { + } + + std::shared_ptr ctx_; + std::unique_ptr request_; + InferenceRequest::SequenceId correlation_id_; + uint32_t flags_; + + std::mutex output_mtx_; + // Different output map to avoid address conflict from different memory types + std::unordered_map> + cpu_output_map_; + std::unordered_map< + int64_t, std::unordered_map>> + gpu_output_map_; + std::set> updated_tensors_; + uint32_t response_flags_; + TRITONSERVER_Error* infer_status_; + + size_t step_idx_; +}; + +struct TensorData { + struct Metadata { + Metadata() = default; + Metadata( + std::unique_ptr&& data, size_t reference_count) + : data_(std::move(data)), remaining_reference_count_(reference_count), + parameter_override_(false) + { + } + Metadata( + std::unique_ptr&& data, size_t reference_count, + const InferenceRequest::SequenceId& correlation_id, uint32_t flags) + : data_(std::move(data)), remaining_reference_count_(reference_count), + parameter_override_(true), correlation_id_(correlation_id), + flags_(flags) + { + } + std::unique_ptr data_; + size_t remaining_reference_count_; + bool parameter_override_; + InferenceRequest::SequenceId correlation_id_; + uint32_t flags_; + }; + TensorData() = default; + TensorData(const size_t outgoing_steps_count) + : current_iteration_(0), outgoing_steps_count_(outgoing_steps_count), + batch_size_(0) + { + } + + IterationCount AddTensor(std::unique_ptr&& tensor) + { + tensor_.emplace( + current_iteration_, Metadata(std::move(tensor), outgoing_steps_count_)); + return current_iteration_++; + } + + IterationCount AddTensor( + std::unique_ptr&& tensor, + const InferenceRequest::SequenceId& correlation_id, uint32_t flags) + { + tensor_.emplace( + current_iteration_, + Metadata( + std::move(tensor), outgoing_steps_count_, correlation_id, flags)); + return current_iteration_++; + } + + // Tensors associated with the particular ensemble tensor. + // A container is used to handle the decoupled case + // where variable number of tensors will be produced. + // map 'iteration count' to pair of + std::unordered_map tensor_; + size_t current_iteration_; + size_t outgoing_steps_count_; + + // Ensemble may be configured to passing tensor between batching model and + // non-batching model as long as the full shapes match and storing the batch + // size of the generated tensor explicitly for checking and setting proper + // shape for the downstream model request. + size_t batch_size_; +}; + +// EnsembleContext maintains the state of the ensemble request +// +// Using static functions to take advantage of shared_ptr, a copy of the +// shared_ptr will be made when a step is scheduled and it will go out of +// scope after the step's callback is finished. The step's callback will +// schedule new steps if available and the last step will finish the ensemble +// request. +// So we don't have to maintian the context in scheduler as the shared_ptr +// will destroy the context for us if there are no "in-flight" steps. 
+class EnsembleContext { + public: + EnsembleContext( + MetricModelReporter* metric_reporter, + InferenceStatsAggregator* stats_aggregator, InferenceServer* is, + EnsembleInfo* info, std::unique_ptr& request, + cudaStream_t stream); + + // Perform transition on 'context' state given the information of + // 'completed_step' + static void Proceed( + const std::shared_ptr& context, + const std::unique_ptr& completed_step = nullptr); + + private: + static TRITONSERVER_Error* ResponseAlloc( + TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type, + int64_t preferred_memory_type_id, void* userp, void** buffer, + void** buffer_userp, TRITONSERVER_MemoryType* allocated_memory_type, + int64_t* allocated_memory_type_id); + static TRITONSERVER_Error* ResponseRelease( + TRITONSERVER_ResponseAllocator* allocator, void* buffer, + void* buffer_userp, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id); + static TRITONSERVER_Error* OutputBufferQuery( + TRITONSERVER_ResponseAllocator* allocator, void* userp, + const char* tensor_name, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id); + static void RequestComplete( + TRITONSERVER_InferenceRequest* request, const uint32_t flags, + void* userp); + static void ResponseComplete( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, + void* userp); + + using StepList = std::vector>; + using VersionMap = std::unordered_map>; + + // Helper function to reshape the given tensor according to the + // config shape and batching info and its actual shape and batching info. + // Note that 'dims' will be in full shape as opposed to 'config_dims'. + // Return the dims after reshape. + std::vector ReshapeTensorDims( + const triton::common::DimsList& config_dims, + const bool config_allow_batching, const size_t tensor_batch_size, + const std::vector& dims); + + // Return the list of step that becomes ready due to tensor update + // from 'completed_step' + Status PrepareSteps( + const std::unique_ptr& completed_step, StepList* steps); + + // Prepare infer stats and call the inference server's function to process + // the infer requests specified in 'steps' + static void ScheduleSteps( + const std::shared_ptr& context, StepList&& steps); + + // Helper function that updates ensemble state given 'completed_step' and + // returns the list of updated tensors in 'updated_tensors' + Status UpdateEnsembleState( + const std::unique_ptr& completed_step, + std::set>* updated_tensors); + + // Helper function that returns a list of 'steps' that should be run under + // current ensemble state. 'updated_tensors' is used so that we don't need to + // iterate all the tensors to determine which step can be run. + Status GetNextSteps( + const std::set>& updated_tensors, + StepList* steps); + + // Helper function that completes the response of the ensemble request + Status FinishEnsemble( + std::unique_ptr&& response = nullptr); + + // Helper function that initialize the 'step' given the info at 'step_idx'. + // The 'step' will have proper request / response provider for the model + Status InitStep( + const size_t step_idx, const IterationCount iteration_count, + std::unique_ptr* step); + + // Helper function that set the output of the ensemble request if it is ready + // and valid. 
+ Status CheckAndSetEnsembleOutput( + const std::set>& updated_tensors, + std::unique_ptr* response); + + InferenceServer* is_; + + EnsembleInfo* info_; + + // All EnsembleContext will use the same CUDA stream managed by + // the ensemble scheduler + cudaStream_t stream_; + + // Mutex to avoid concurrent call on 'PrepareSteps' where ensemble state + // are being modified + std::mutex mutex_; + + size_t inflight_step_counter_; + + // pointer that either points to 'pruned_tensor_to_step_' or to + // 'info_->tensor_to_step_' if all ensemble outputs are requested + std::unordered_map>* tensor_to_step_; + + std::unordered_map> pruned_tensor_to_step_; + std::unordered_map tensor_data_; + + // Handle to all models that may be used in the ensemble + std::unordered_map handles_; + + // Request specific information that obtained from ensemble request and + // should be applied to all internal requests + uint32_t flags_; + std::string request_id_; + InferenceRequest::SequenceId correlation_id_; + uint32_t priority_; + uint64_t timeout_; + + // Objects related to the ensemble infer request + Status ensemble_status_; + RequestTracker* request_tracker_; + + // The allocator that will be used to allocate buffers for the + // inference result tensors. + std::unique_ptr< + TRITONSERVER_ResponseAllocator, + decltype(&TRITONSERVER_ResponseAllocatorDelete)> + allocator_; +}; + +EnsembleContext::EnsembleContext( + MetricModelReporter* metric_reporter, + InferenceStatsAggregator* stats_aggregator, InferenceServer* is, + EnsembleInfo* info, std::unique_ptr& request, + cudaStream_t stream) + : is_(is), info_(info), stream_(stream), inflight_step_counter_(0), + allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete) +{ + uint64_t compute_start_ns = 0; + INFER_STATS_SET_TIMESTAMP(compute_start_ns); + request_tracker_ = new RequestTracker( + std::move(request), compute_start_ns, metric_reporter, stats_aggregator); + + auto& lrequest = request_tracker_->Request(); + + // Obtain model handles of all models in ensemble request such that + // they have the same lifetime as the ensemble request to avoid unloading + // while the ensemble is executing. 
+ for (const auto& step_info : info_->steps_) { + auto it = handles_.find(step_info.model_name_); + if (it == handles_.end()) { + it = handles_.emplace(std::make_pair(step_info.model_name_, VersionMap())) + .first; + } + auto ver_it = it->second.find(step_info.model_version_); + if (ver_it == it->second.end()) { + std::shared_ptr model = nullptr; + ensemble_status_ = is_->GetModel( + step_info.model_name_, step_info.model_version_, &model); + if (!ensemble_status_.IsOk()) { + break; + } + + it->second.emplace(std::make_pair(step_info.model_version_, model)); + } + } + + // Prune ensemble first if not all outputs are requested + std::set ignored_tensor; + for (const auto& ensemble_output : info_->ensemble_output_shape_) { + ignored_tensor.insert(ensemble_output.first); + } + for (const auto& requested_output : lrequest->ImmutableRequestedOutputs()) { + ignored_tensor.erase(requested_output); + } + if (ignored_tensor.empty()) { + tensor_to_step_ = &(info_->tensor_to_step_); + } else { + pruned_tensor_to_step_ = info_->tensor_to_step_; + tensor_to_step_ = &pruned_tensor_to_step_; + // Backward traversal + std::unordered_map step_requested_output_count; + while (!ignored_tensor.empty()) { + std::set new_ignored_tensor; + for (const auto& output : ignored_tensor) { + auto step_idx = info_->tensor_to_prev_step_[output]; + auto& step = info_->steps_[step_idx]; + auto it = step_requested_output_count.find(step_idx); + if (it == step_requested_output_count.end()) { + auto output_count = step.output_to_tensor_.size(); + it = + step_requested_output_count.emplace(step_idx, output_count).first; + } + // If none of the outputs of the step is requested, + // then the step can be pruned + if (--it->second == 0) { + for (const auto& input : step.input_to_tensor_) { + auto& step_set = pruned_tensor_to_step_[input.second]; + step_set.erase(step_idx); + // If all steps depend on a tensor are pruned, + // then the tensor can be ignored. + if (step_set.empty()) { + new_ignored_tensor.insert(input.second); + } + } + } + } + ignored_tensor.swap(new_ignored_tensor); + } + } + + for (const auto& pair : *tensor_to_step_) { + const auto& requested_outputs = lrequest->ImmutableRequestedOutputs(); + // For requested outputs, add 1 to outgoing count as the ensemble itself + // isn't counted as step. + if (requested_outputs.find(pair.first) != requested_outputs.end()) { + tensor_data_.emplace(pair.first, TensorData(pair.second.size() + 1)); + } else { + tensor_data_.emplace(pair.first, TensorData(pair.second.size())); + } + } + + if (ensemble_status_.IsOk()) { + request_id_ = lrequest->Id(); + correlation_id_ = lrequest->CorrelationId(); + flags_ = lrequest->Flags(); + priority_ = lrequest->Priority(); + timeout_ = lrequest->TimeoutMicroseconds(); + + for (const auto& pr : lrequest->ImmutableInputs()) { + const InferenceRequest::Input* input = pr.second; + auto it = tensor_data_.find(input->Name()); + if (it != tensor_data_.end()) { + auto& tensor_data = it->second; + // Shape() represents reshaped value without batch dimension, + // thus need to fill it if necessary. 
+ std::unique_ptr tensor; + if (lrequest->BatchSize() != 0) { + std::vector shape{lrequest->BatchSize()}; + shape.insert( + shape.end(), input->Shape().begin(), input->Shape().end()); + tensor.reset(new InferenceRequest::Input( + input->Name(), input->DType(), shape)); + } else { + tensor.reset(new InferenceRequest::Input( + input->Name(), input->DType(), input->Shape())); + } + tensor->SetData(input->Data()); + for (const auto& host_policy_data : input->HostPolicyData()) { + tensor->SetData(host_policy_data.first, host_policy_data.second); + } + tensor_data.AddTensor(std::move(tensor)); + tensor_data.batch_size_ = lrequest->BatchSize(); + } else { + ensemble_status_ = Status( + Status::Code::INVALID_ARG, + lrequest->LogRequest() + "unexpected input '" + input->Name() + + "' in request header that does not map to any ensemble inputs"); + } + } + + // Iterate the ensemble optional inputs and add empty tensor data entry + // if the input is not provided + for (const auto& name : info_->optional_inputs_) { + auto it = tensor_data_.find(name); + if ((it != tensor_data_.end()) && it->second.tensor_.empty()) { + it->second.AddTensor(nullptr); + it->second.batch_size_ = lrequest->BatchSize(); + } + } + } + + TRITONSERVER_ResponseAllocator* allocator; + TRITONSERVER_Error* err = TRITONSERVER_ResponseAllocatorNew( + &allocator, ResponseAlloc, ResponseRelease, nullptr /* start_fn */); + if (err == nullptr) { + err = TRITONSERVER_ResponseAllocatorSetQueryFunction( + allocator, OutputBufferQuery); + } + if (err != nullptr) { + ensemble_status_ = Status( + TritonCodeToStatusCode(TRITONSERVER_ErrorCode(err)), + TRITONSERVER_ErrorMessage(err)); + TRITONSERVER_ErrorDelete(err); + } else { + allocator_.reset(allocator); + } +} + +TRITONSERVER_Error* +EnsembleContext::ResponseAlloc( + TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type, + int64_t preferred_memory_type_id, void* userp, void** buffer, + void** buffer_userp, TRITONSERVER_MemoryType* allocated_memory_type, + int64_t* allocated_memory_type_id) +{ + *buffer = nullptr; + *buffer_userp = nullptr; + + auto allocated_buffer = std::make_shared( + byte_size, preferred_memory_type, preferred_memory_type_id); + + auto mutable_buffer = allocated_buffer->MutableBuffer( + allocated_memory_type, allocated_memory_type_id); + if ((mutable_buffer != nullptr) || (byte_size == 0)) { + if (byte_size != 0) { + *buffer = static_cast(mutable_buffer); + auto step = reinterpret_cast(userp); + std::lock_guard lk(step->output_mtx_); + if (*allocated_memory_type == TRITONSERVER_MEMORY_GPU) { + step->gpu_output_map_[*allocated_memory_type_id].emplace( + reinterpret_cast(*buffer), std::move(allocated_buffer)); + } else { + step->cpu_output_map_.emplace( + reinterpret_cast(*buffer), std::move(allocated_buffer)); + } + } + LOG_VERBOSE(1) << "Internal response allocation: " << tensor_name + << ", size " << byte_size << ", addr " << *buffer + << ", memory type " << *allocated_memory_type << ", type id " + << *allocated_memory_type_id; + } + + return nullptr; // Success +} + +TRITONSERVER_Error* +EnsembleContext::ResponseRelease( + TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp, + size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) +{ + LOG_VERBOSE(1) << "Internal response release: " + << "size " << byte_size << ", addr " << buffer; + + // Don't do anything when releasing a buffer since ResponseAlloc + // passes the ownership of the data to 
ensemble context. + return nullptr; // Success +} + +TRITONSERVER_Error* +EnsembleContext::OutputBufferQuery( + TRITONSERVER_ResponseAllocator* allocator, void* userp, + const char* tensor_name, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) +{ + // Ensemble will always attempt to satisfy any output buffer request + return nullptr; // Success +} + +void +EnsembleContext::RequestComplete( + TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp) +{ + if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) { + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestDelete(request), + "deleting ensemble inference request"); + auto request_tracker = reinterpret_cast(userp); + if (request_tracker->DecrementCounter()) { + delete request_tracker; + } + } +} + +void +EnsembleContext::ResponseComplete( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp) +{ + auto step_ptr = std::unique_ptr(reinterpret_cast(userp)); + step_ptr->response_flags_ = flags; + + if (response != nullptr) { + auto err = TRITONSERVER_InferenceResponseError(response); + uint32_t count; + bool parameter_override = false; + InferenceRequest::SequenceId correlation_id{0}; + uint32_t flags = 0; + if (err == nullptr) { + err = TRITONSERVER_InferenceResponseParameterCount(response, &count); + if (err == nullptr) { + for (uint32_t idx = 0; idx < count; idx++) { + const char* name; + TRITONSERVER_ParameterType type; + const void* vvalue; + err = TRITONSERVER_InferenceResponseParameter( + response, idx, &name, &type, &vvalue); + if (err == nullptr) { + if (!strcmp(name, "sequence_id")) { + switch (type) { + case TRITONSERVER_PARAMETER_INT: + correlation_id = InferenceRequest::SequenceId( + *reinterpret_cast(vvalue)); + parameter_override = true; + break; + case TRITONSERVER_PARAMETER_STRING: + correlation_id = InferenceRequest::SequenceId(std::string( + *reinterpret_cast(vvalue))); + parameter_override = true; + break; + default: + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "expected parameter 'sequence_id' to be " + "TRITONSERVER_PARAMETER_INT or " + "TRITONSERVER_PARAMETER_STRING"); + } + } else if (!strcmp(name, "sequence_start")) { + if (type != TRITONSERVER_PARAMETER_BOOL) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "expect paremeter 'sequence_start' to be " + "TRITONSERVER_PARAMETER_BOOL"); + } else { + if (*reinterpret_cast(vvalue)) { + flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_START; + } + parameter_override = true; + } + } else if (!strcmp(name, "sequence_end")) { + if (type != TRITONSERVER_PARAMETER_BOOL) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "expect paremeter 'sequence_end' to be " + "TRITONSERVER_PARAMETER_BOOL"); + } else { + if (*reinterpret_cast(vvalue)) { + flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_END; + } + parameter_override = true; + } + } + } + } + } + } + if (err == nullptr) { + err = TRITONSERVER_InferenceResponseOutputCount(response, &count); + if (err == nullptr) { + std::lock_guard lock(step_ptr->ctx_->mutex_); + auto& output_to_tensor = + step_ptr->ctx_->info_->steps_[step_ptr->step_idx_] + .output_to_tensor_; + for (uint32_t idx = 0; idx < count; idx++) { + const char* name; + TRITONSERVER_DataType datatype; + const int64_t* shape; + uint64_t dim_count; + const void* base; + size_t byte_size; + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + void* userp; + err = TRITONSERVER_InferenceResponseOutput( + response, idx, &name, 
&datatype, &shape, &dim_count, &base, + &byte_size, &memory_type, &memory_type_id, &userp); + if (err == nullptr) { + auto it = output_to_tensor.find(name); + if (it != output_to_tensor.end()) { + std::unique_ptr tensor( + new InferenceRequest::Input( + it->second, TritonToDataType(datatype), shape, + dim_count)); + + if (byte_size != 0) { + std::lock_guard output_lk(step_ptr->output_mtx_); + if (memory_type == TRITONSERVER_MEMORY_GPU) { + auto& gpu_output_map = + step_ptr->gpu_output_map_[memory_type_id]; + auto it = + gpu_output_map.find(reinterpret_cast(base)); + tensor->SetData(std::move(it->second)); + gpu_output_map.erase(it); + } else { + auto it = step_ptr->cpu_output_map_.find( + reinterpret_cast(base)); + tensor->SetData(std::move(it->second)); + step_ptr->cpu_output_map_.erase(it); + } + } + + auto& tensor_data = step_ptr->ctx_->tensor_data_[it->second]; + if (parameter_override) { + step_ptr->updated_tensors_.emplace( + it->second, tensor_data.AddTensor( + std::move(tensor), correlation_id, flags)); + } else { + step_ptr->updated_tensors_.emplace( + it->second, + tensor_data.AddTensor( + std::move(tensor), step_ptr->correlation_id_, + step_ptr->flags_)); + } + } else { + LOG_VERBOSE(1) + << "in ensemble, an internal response header specified " + "output '" + << name << "' that does not map to any ensemble tensors"; + } + } + if (err != nullptr) { + break; + } + } + } + } + + if (err != nullptr) { + step_ptr->infer_status_ = err; + } + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceResponseDelete(response), + "deleting inference response"); + } + + EnsembleContext::Proceed(step_ptr->ctx_, step_ptr); + // Expecting more responses + if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) { + step_ptr.release(); + } +} + +void +EnsembleContext::Proceed( + const std::shared_ptr& context, + const std::unique_ptr& completed_step) +{ + StepList ready_steps; + Status status = context->PrepareSteps(completed_step, &ready_steps); + if (status.IsOk()) { + ScheduleSteps(context, std::move(ready_steps)); + } +} + +Status +EnsembleContext::PrepareSteps( + const std::unique_ptr& completed_step, StepList* ready_steps) +{ + { + std::lock_guard lock(mutex_); + + // Initialization error, ensemble status will be not ok since the beginning + if (completed_step == nullptr && !ensemble_status_.IsOk()) { + ensemble_status_ = FinishEnsemble(); + } + + if (ensemble_status_.IsOk()) { + StepList res; + std::set> updated_tensors; + ensemble_status_ = UpdateEnsembleState(completed_step, &updated_tensors); + if (ensemble_status_.IsOk()) { + ensemble_status_ = GetNextSteps(updated_tensors, ready_steps); + } + + // Check and send ensemble response + if ((!ensemble_status_.IsOk()) || (inflight_step_counter_ == 0) || + info_->is_decoupled_) { + std::unique_ptr response; + if (ensemble_status_.IsOk()) { + ensemble_status_ = + CheckAndSetEnsembleOutput(updated_tensors, &response); + } + ensemble_status_ = FinishEnsemble(std::move(response)); + } + } + return ensemble_status_; + } +} + +Status +EnsembleContext::UpdateEnsembleState( + const std::unique_ptr& completed_step, + std::set>* updated_tensors) +{ + updated_tensors->clear(); + if (completed_step == nullptr) { + for (const auto& tensor_data : tensor_data_) { + if (!tensor_data.second.tensor_.empty()) { + updated_tensors->emplace(tensor_data.first, 0); + } + } + } else { + if (completed_step->response_flags_ & + TRITONSERVER_RESPONSE_COMPLETE_FINAL) { + inflight_step_counter_--; + } + RETURN_IF_TRITONSERVER_ERROR(completed_step->infer_status_); + 
updated_tensors->swap(completed_step->updated_tensors_); + } + return Status::Success; +} + +Status +EnsembleContext::GetNextSteps( + const std::set>& updated_tensors, + StepList* steps) +{ + steps->clear(); + + std::set> next_step_idx; + // Get steps whose tensors used for input are set + for (const auto updated_tensor : updated_tensors) { + const auto& step_idx = (*tensor_to_step_)[updated_tensor.first]; + for (const auto& idx : step_idx) { + bool ready = true; + for (const auto& input_pair : info_->steps_[idx].input_to_tensor_) { + auto& tensor = tensor_data_[input_pair.second].tensor_; + if (tensor.empty()) { + ready = false; + break; + } else { + // Check if other inputs have tensor with corresponding iteration + // count + if (tensor.find(updated_tensor.second) == tensor.end()) { + ready = false; + break; + } + } + } + if (ready) { + next_step_idx.emplace(idx, updated_tensor.second); + } + } + } + + for (const auto& idx : next_step_idx) { + steps->emplace_back(); + RETURN_IF_ERROR(InitStep(idx.first, idx.second, &(steps->back()))); + } + inflight_step_counter_ += steps->size(); + + return Status::Success; +} + +Status +EnsembleContext::InitStep( + const size_t step_idx, const IterationCount iteration_count, + std::unique_ptr* step) +{ + const auto& istep = info_->steps_[step_idx]; + auto& version_map = handles_[istep.model_name_]; + auto& model = version_map[istep.model_version_]; + + const bool allow_batching = (model->Config().max_batch_size() > 0); + + auto irequest = std::unique_ptr( + new InferenceRequest(model, istep.model_version_)); + + // Store the pointers to tensors used so that we can prune them afterward. + // Can't prune the tensor in the input loop below as it may be used by + // multiple inputs in the same step. + std::map releasing_tensors; + + // Set inputs in request, prepare input map, + // and set overridden parameter if any. + auto correlation_id = correlation_id_; + auto flags = flags_; + bool parameter_set = false; + for (const auto& pair : istep.input_to_tensor_) { + auto& tensor_data = tensor_data_[pair.second]; + auto& tensor = tensor_data.tensor_[iteration_count]; + + // nullptr if and only if the tensor is optional ensemble input and + // not provided in the ensemble request. In such case, we don't add + // the input and expect the ensemble pipeline is configured correctly + // (the input to the inner model is also optional) + if (tensor.data_ != nullptr) { + // If the actual shape and config shape agree with each other without + // considering batch size, non-batch / batch conversion are not required. + const inference::ModelInput* input_config; + model->GetInput(pair.first, &input_config); + auto shape = ReshapeTensorDims( + input_config->dims(), allow_batching, tensor_data.batch_size_, + tensor.data_->OriginalShape()); + + InferenceRequest::Input* input; + RETURN_IF_ERROR(irequest->AddOriginalInput( + pair.first, tensor.data_->DType(), shape, &input)); + RETURN_IF_ERROR(input->SetData(tensor.data_->Data())); + for (const auto& host_policy_data : tensor.data_->HostPolicyData()) { + RETURN_IF_ERROR( + input->SetData(host_policy_data.first, host_policy_data.second)); + } + } + + releasing_tensors.emplace(&tensor_data, &tensor.remaining_reference_count_); + + if (tensor.parameter_override_) { + if (parameter_set && ((correlation_id != tensor.correlation_id_) || + (flags != tensor.flags_))) { + LOG_ERROR << irequest->LogRequest() + << "Different set of response parameters are set for '" + << istep.model_name_ << "'. 
Parameter correlation ID " + << correlation_id << ", flags " << flags << " is used."; + continue; + } + correlation_id = tensor.correlation_id_; + flags = tensor.flags_; + parameter_set = true; + } + } + + // Prune the tensor if it is not needed by other steps + for (auto& releasing_pair : releasing_tensors) { + if ((--(*releasing_pair.second)) == 0) { + releasing_pair.first->tensor_.erase(iteration_count); + } + } + + // Set requested outputs in request header + for (const auto& pair : istep.output_to_tensor_) { + irequest->AddOriginalRequestedOutput(pair.first); + } + + step->reset(new Step(step_idx, correlation_id, flags)); + + irequest->SetId(request_id_); + irequest->SetCorrelationId(correlation_id); + irequest->SetFlags(flags); + irequest->SetPriority(priority_); + irequest->SetTimeoutMicroseconds(timeout_); +#ifdef TRITON_ENABLE_STATS + irequest->SetSecondaryStatsAggregator( + &request_tracker_->ContextStatsAggregator()); +#endif + irequest->SetResponseCallback( + reinterpret_cast(allocator_.get()), step->get(), + ResponseComplete, step->get()); + irequest->SetReleaseCallback(RequestComplete, request_tracker_); + + RETURN_IF_ERROR(irequest->PrepareForInference()); + +#ifdef TRITON_ENABLE_TRACING + auto& parent_trace = request_tracker_->Request()->Trace(); + if (parent_trace != nullptr) { + irequest->SetTrace(parent_trace->SpawnChildTrace()); + irequest->Trace()->SetModelName(irequest->ModelName()); + irequest->Trace()->SetModelVersion(irequest->ActualModelVersion()); + } +#endif + + // Record the batch size of output in advance as + // there is no other way to access it later on. + for (const auto& pair : istep.output_to_tensor_) { + auto& output_data_ = tensor_data_[pair.second]; + output_data_.batch_size_ = irequest->BatchSize(); + } + + (*step)->request_ = std::move(irequest); + + return Status::Success; +} + +std::vector +EnsembleContext::ReshapeTensorDims( + const triton::common::DimsList& config_dims, + const bool config_allow_batching, const size_t tensor_batch_size, + const std::vector& dims) +{ + bool reshaped = false; + std::vector res; + + // Only attempt to reshape if one setting is batchable while the other is not, + // the case of two mismatched batchable shapes is not considered. + // If the actual shape and config shape agree with each other without + // considering batch size, non-batch / batch conversion are not required. + if (config_allow_batching != (tensor_batch_size != 0)) { + // expect batching but the tensor is generated from nobatching model + if (config_allow_batching) { + if (triton::common::CompareDimsWithWildcard(config_dims, dims)) { + // If 'dims' already matches 'config_dims', prepend with batch size 1 + res.push_back(1); + res.insert(res.end(), dims.begin(), dims.end()); + reshaped = true; + } + // Otherwise, assuming the tensor is already in the batch expected + // by the model and do nothing + } else { + // Check if the batched tensor can be sent to the non-batching + // model as one tensor. 
If not, strip the batch dimension if + // it is batch size 1 + if (!triton::common::CompareDimsWithWildcard(config_dims, dims) && + (tensor_batch_size == 1)) { + res.assign(dims.begin() + 1, dims.end()); + reshaped = true; + } + } + } + + if (!reshaped) { + res = dims; + } + return res; +} + +Status +EnsembleContext::FinishEnsemble(std::unique_ptr&& response) +{ + // Do nothing if the ensemble is finished + if (request_tracker_ == nullptr) { + return ensemble_status_; + } + + // Add ensemble name to make error message more trackable + if (!ensemble_status_.IsOk()) { + ensemble_status_ = Status( + ensemble_status_.StatusCode(), "in ensemble '" + info_->ensemble_name_ + + "', " + ensemble_status_.Message()); + } + + if (ensemble_status_.IsOk()) { + if (info_->is_decoupled_) { + if (response != nullptr) { + InferenceResponse::Send(std::move(response), 0 /* flags */); + } + if (inflight_step_counter_ != 0) { + return ensemble_status_; + } + request_tracker_->Request()->ResponseFactory()->SendFlags( + TRITONSERVER_RESPONSE_COMPLETE_FINAL); + } else { + InferenceResponse::Send( + std::move(response), TRITONSERVER_RESPONSE_COMPLETE_FINAL); + } + } else { + if (response != nullptr) { + InferenceResponse::SendWithStatus( + std::move(response), TRITONSERVER_RESPONSE_COMPLETE_FINAL, + ensemble_status_); + } else { + InferenceRequest::RespondIfError( + request_tracker_->Request(), ensemble_status_); + } + } + + // Reach here when the ensemble execution comes to the end, 'ensemble_status_' + // at this point is representative. + request_tracker_->SetStatus(ensemble_status_); + if (request_tracker_->DecrementCounter()) { + delete request_tracker_; + } + request_tracker_ = nullptr; + return ensemble_status_; +} + +Status +EnsembleContext::CheckAndSetEnsembleOutput( + const std::set>& updated_tensors, + std::unique_ptr* response) +{ + IterationCount iteration_count = 0; + // Check if updated tensor is one of the ensemble output and if all outputs + // have tensor of the same iteration count + bool ready = false; + auto& lrequest = request_tracker_->Request(); + const auto& requested_outputs = lrequest->ImmutableRequestedOutputs(); + for (const auto updated_tensor : updated_tensors) { + if (requested_outputs.find(updated_tensor.first) == + requested_outputs.end()) { + continue; + } + + ready = true; + iteration_count = updated_tensor.second; + for (const auto& output : requested_outputs) { + auto& tensor = tensor_data_[output].tensor_; + if (tensor.empty()) { + ready = false; + break; + } else { + // Check if other outputs have tensor with corresponding iteration count + if (tensor.find(iteration_count) == tensor.end()) { + ready = false; + break; + } + } + } + } + if (!ready) { + if (info_->is_decoupled_) { + return Status::Success; + } + return Status( + Status::Code::INVALID_ARG, + lrequest->LogRequest() + + "unexpected deadlock, at least one output is not set while no more " + "ensemble steps can be made"); + } + + RETURN_IF_ERROR(lrequest->ResponseFactory()->CreateResponse(response)); + + bool cuda_async_copy = false; + std::map releasing_tensors; + for (const auto& output_pair : info_->ensemble_output_shape_) { + if (requested_outputs.find(output_pair.first) == requested_outputs.end()) { + continue; + } + // Check if output is ready + auto& tensor_data = tensor_data_[output_pair.first]; + auto& tensor = tensor_data.tensor_[iteration_count]; + + auto shape = ReshapeTensorDims( + output_pair.second, (lrequest->BatchSize() != 0), + tensor_data.batch_size_, tensor.data_->OriginalShape()); + + 
InferenceResponse::Output* output; + RETURN_IF_ERROR((*response)->AddOutput( + output_pair.first, tensor.data_->DType(), shape, &output)); + + // Use the memory type of the memory block as preferred memory type + TRITONSERVER_MemoryType dst_memory_type; + int64_t dst_memory_type_id; + size_t content_size; + tensor.data_->Data()->BufferAt( + 0, &content_size, &dst_memory_type, &dst_memory_type_id); + + void* buffer; + RETURN_IF_ERROR(output->AllocateDataBuffer( + &buffer, content_size, &dst_memory_type, &dst_memory_type_id)); + + // Done with this output if 'expected_byte_size' is 0 + if (content_size == 0) { + continue; + } else if (buffer == nullptr) { + return Status( + Status::Code::INTERNAL, + "failed to allocate buffer for output '" + output_pair.first + "'"); + } + + size_t content_offset = 0; + size_t content_idx = 0; + TRITONSERVER_MemoryType src_memory_type; + int64_t src_memory_type_id; + + const char* content = tensor.data_->Data()->BufferAt( + content_idx, &content_size, &src_memory_type, &src_memory_type_id); + bool cuda_used = false; + while (content != nullptr) { + RETURN_IF_ERROR(CopyBuffer( + output_pair.first, src_memory_type, src_memory_type_id, + dst_memory_type, dst_memory_type_id, content_size, content, + ((char*)buffer) + content_offset, stream_, &cuda_used)); + cuda_async_copy |= cuda_used; + + content_offset += content_size; + content_idx++; + content = tensor.data_->Data()->BufferAt( + content_idx, &content_size, &src_memory_type, &src_memory_type_id); + } + + releasing_tensors.emplace(&tensor_data, &tensor.remaining_reference_count_); + + if (tensor.parameter_override_) { + switch (lrequest->CorrelationId().Type()) { + case InferenceRequest::SequenceId::DataType::STRING: + (*response)->AddParameter( + "sequence_id", tensor.correlation_id_.StringValue().c_str()); + break; + case InferenceRequest::SequenceId::DataType::UINT64: + (*response)->AddParameter( + "sequence_id", + (int64_t)tensor.correlation_id_.UnsignedIntValue()); + break; + default: + (*response)->AddParameter( + "sequence_id", + (int64_t)tensor.correlation_id_.UnsignedIntValue()); + break; + } + (*response)->AddParameter( + "sequence_start", + (tensor.flags_ & TRITONSERVER_REQUEST_FLAG_SEQUENCE_START) != 0); + (*response)->AddParameter( + "sequence_end", + (tensor.flags_ & TRITONSERVER_REQUEST_FLAG_SEQUENCE_END) != 0); + } + } + + if (cuda_async_copy) { +#ifdef TRITON_ENABLE_GPU + cudaStreamSynchronize(stream_); +#else + return Status( + Status::Code::INTERNAL, + "unexpected CUDA copy flag set while GPU is not supported"); +#endif // TRITON_ENABLE_GPU + } + + // Prune the tensor if it is not needed by other steps + for (auto& releasing_pair : releasing_tensors) { + if ((--(*releasing_pair.second)) == 0) { + releasing_pair.first->tensor_.erase(iteration_count); + } + } + + return Status::Success; +} + +void +EnsembleContext::ScheduleSteps( + const std::shared_ptr& context, StepList&& steps) +{ + for (auto& step : steps) { + step->ctx_ = context; + bool should_schedule = false; + // Must release lock before InferAsync to avoid deadlock, as the same thread + // will be calling request/response callbacks on cache hits, which will + // attempt to acquire the lock already held + { + std::lock_guard lock(context->mutex_); + + // Need to check the ensemble_status_ to ensure the FinishEnsemble() + // is called only once. 
+ if (context->ensemble_status_.IsOk()) { + context->request_tracker_->IncrementCounter(); + should_schedule = true; + } + } + if (should_schedule) { + // On a successful call to InferAsync(), the step will be released by + // the response callback. When the response callback is invoked, the + // step must not own (and release) the request as the request should be + // transferred and managed by Triton core. In the case of cache hit, the + // request hasn't been transferred and can cause double-free, so moving + // the request ownership out of step here to avoid that + std::unique_ptr request = std::move(step->request_); + auto step_status = context->is_->InferAsync(request); + if (!step_status.IsOk()) { + std::lock_guard lock(context->mutex_); + context->ensemble_status_ = step_status; + // The request is not sent to server properly, shouldn't expect its + // release function get called. + context->request_tracker_->DecrementCounter(); + context->ensemble_status_ = context->FinishEnsemble(); + break; + } + } + step.release(); + } +} + +} // namespace + +Status +EnsembleScheduler::Create( + InferenceStatsAggregator* const stats_aggregator, + InferenceServer* const server, const inference::ModelConfig& config, + std::unique_ptr* scheduler) +{ + scheduler->reset(new EnsembleScheduler(stats_aggregator, server, config)); + return Status::Success; +} + +Status +EnsembleScheduler::Enqueue(std::unique_ptr& request) +{ + // Queue timer starts at the beginning of the queueing and + // scheduling process + request->CaptureQueueStartNs(); + INFER_TRACE_ACTIVITY( + request->Trace(), TRITONSERVER_TRACE_QUEUE_START, + request->QueueStartNs()); +#ifdef TRITON_ENABLE_TRACING + request->TraceInputTensors( + TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT, "EnsembleScheduler Enqueue"); +#endif // TRITON_ENABLE_TRACING + + // Add additional callback to keep track of in-flight count + ++inflight_count_; + request->AddInternalReleaseCallback([this]() { --inflight_count_; }); + std::shared_ptr context(new EnsembleContext( + metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request, + stream_)); + EnsembleContext::Proceed(context); + return Status::Success; +} + +EnsembleScheduler::EnsembleScheduler( + InferenceStatsAggregator* const stats_aggregator, + InferenceServer* const server, const inference::ModelConfig& config) + : stats_aggregator_(stats_aggregator), is_(server), stream_(nullptr), + inflight_count_(0) +{ +#ifdef TRITON_ENABLE_GPU + // create CUDA stream + auto cuerr = cudaStreamCreate(&stream_); + if (cuerr != cudaSuccess) { + stream_ = nullptr; + LOG_ERROR << "unable to create stream for " << config.name() << ": " + << cudaGetErrorString(cuerr); + } +#endif // TRITON_ENABLE_GPU + +#ifdef TRITON_ENABLE_METRICS + if (Metrics::Enabled()) { + MetricModelReporter::Create( + config.name(), 1, METRIC_REPORTER_ID_CPU, config.metric_tags(), + &metric_reporter_); + } +#endif // TRITON_ENABLE_METRICS + + // Set 'info_' based on 'config' + info_.reset(new EnsembleInfo()); + + info_->ensemble_name_ = config.name(); + + // This config field is filled internally for ensemble models + info_->is_decoupled_ = config.model_transaction_policy().decoupled(); + + for (const auto& input : config.input()) { + info_->tensor_to_step_.emplace(input.name(), std::set()); + if (input.optional()) { + info_->optional_inputs_.emplace(input.name()); + } + } + for (const auto& output : config.output()) { + info_->tensor_to_step_.emplace(output.name(), std::set()); + + if (output.has_reshape()) { + 
info_->ensemble_output_shape_[output.name()] = output.reshape().shape(); + } else { + info_->ensemble_output_shape_[output.name()] = output.dims(); + } + } + + for (const auto& element : config.ensemble_scheduling().step()) { + size_t step_idx = info_->steps_.size(); + info_->steps_.emplace_back(element.model_name(), element.model_version()); + for (const auto& pair : element.input_map()) { + auto it = info_->tensor_to_step_.find(pair.second); + if (it == info_->tensor_to_step_.end()) { + it = info_->tensor_to_step_.emplace(pair.second, std::set()) + .first; + } + it->second.insert(step_idx); + info_->steps_[step_idx].input_to_tensor_.emplace( + std::make_pair(pair.first, pair.second)); + } + + for (const auto& pair : element.output_map()) { + auto it = info_->tensor_to_step_.find(pair.second); + if (it == info_->tensor_to_step_.end()) { + it = info_->tensor_to_step_.emplace(pair.second, std::set()) + .first; + } + info_->steps_[step_idx].output_to_tensor_.emplace( + std::make_pair(pair.first, pair.second)); + + info_->tensor_to_prev_step_.emplace(pair.second, step_idx); + } + } +} + +EnsembleScheduler::~EnsembleScheduler() +{ +#ifdef TRITON_ENABLE_GPU + if (stream_ != nullptr) { + cudaError_t err = cudaStreamDestroy(stream_); + if (err != cudaSuccess) { + LOG_ERROR << "Failed to destroy cuda stream: " << cudaGetErrorString(err); + } + } +#endif // TRITON_ENABLE_GPU +} + +}} // namespace triton::core + +#endif // TRITON_ENABLE_ENSEMBLE diff --git a/3rdparty/core-r22.12/src/ensemble_scheduler.h b/3rdparty/core-r22.12/src/ensemble_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..0305982a7c306b7020305c3743c05050c43a18cf --- /dev/null +++ b/3rdparty/core-r22.12/src/ensemble_scheduler.h @@ -0,0 +1,123 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+#pragma once + +#ifdef TRITON_ENABLE_ENSEMBLE + +#include +#include "metric_model_reporter.h" +#include "model_config.pb.h" +#include "model_config_utils.h" +#include "scheduler.h" +#include "status.h" + +#ifdef TRITON_ENABLE_GPU +#include +#endif // TRITON_ENABLE_GPU + +namespace triton { namespace core { + +#ifndef TRITON_ENABLE_GPU +using cudaStream_t = void*; +#endif // TRITON_ENABLE_GPU + +class InferenceServer; + +struct EnsembleInfo { + struct StepInfo { + StepInfo(const std::string& model_name, const int64_t model_version) + : model_name_(model_name), model_version_(model_version) + { + } + + std::string model_name_; + int64_t model_version_; + std::unordered_map input_to_tensor_; + std::unordered_map output_to_tensor_; + }; + + std::string ensemble_name_; + + bool is_decoupled_; + + // the ensemble output (re)shape expected by the ensemble + std::unordered_map + ensemble_output_shape_; + + // Inputs that is marked optional for the ensemble + std::set optional_inputs_; + + std::vector steps_; + + // Only include a step if the ensemble tensor is used as input in that step + std::unordered_map> tensor_to_step_; + + // backward path, ensemble tensor to the step that provides its data + std::unordered_map tensor_to_prev_step_; +}; + +// Scheduler that implements ensemble scheduling. +class EnsembleScheduler : public Scheduler { + public: + // Create a scheduler to process ensemble requests and + // to dispatch requests to models in ensemble internally. + static Status Create( + InferenceStatsAggregator* const stats_aggregator, + InferenceServer* const server, const inference::ModelConfig& config, + std::unique_ptr* scheduler); + + ~EnsembleScheduler(); + + // \see Scheduler::Enqueue() + Status Enqueue(std::unique_ptr& request) override; + + // \see Scheduler::InflightInferenceCount() + size_t InflightInferenceCount() override { return inflight_count_; } + + // \see Scheduler::Stop() + void Stop() override {} + + private: + EnsembleScheduler( + InferenceStatsAggregator* const stats_aggregator, + InferenceServer* const server, const inference::ModelConfig& config); + + std::shared_ptr metric_reporter_; + InferenceStatsAggregator* const stats_aggregator_; + InferenceServer* const is_; + + // Ensemble information that is built from model config + std::unique_ptr info_; + + // The stream used for data transfer. + cudaStream_t stream_; + + std::atomic inflight_count_; +}; + +}} // namespace triton::core + +#endif // TRITON_ENABLE_ENSEMBLE diff --git a/3rdparty/core-r22.12/src/ensemble_utils.cc b/3rdparty/core-r22.12/src/ensemble_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..dc4b6c5e853fdbf9e98c902cae59e1572b171860 --- /dev/null +++ b/3rdparty/core-r22.12/src/ensemble_utils.cc @@ -0,0 +1,370 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. 
+// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifdef TRITON_ENABLE_ENSEMBLE + +#include "ensemble_utils.h" + +#include +#include "constants.h" +#include "model.h" +#include "model_config_utils.h" +#include "triton/common/logging.h" + +namespace triton { namespace core { + +namespace { + +/// A basic unit in ensemble graph that records the data type and shape +/// of the ensemble tensor and which model they are inferred from. +struct TensorNode { + TensorNode( + const std::string& model_name, const bool batching, + const inference::DataType& type, const triton::common::DimsList& dims) + : model_name_(model_name), type_(type), dims_(dims), is_decoupled_(false), + decouple_label_(0), visited_(false) + { + // Expand dims to full shape, which includes batch dimension if exist + if (batching) { + full_dims_.Add(-1); + } + full_dims_.MergeFrom(dims_); + } + + // Constructor for symbolic nodes + TensorNode(const std::string& model_name) + : model_name_(model_name), is_decoupled_(false), decouple_label_(0), + visited_(false) + { + } + + std::string model_name_; + inference::DataType type_; + triton::common::DimsList dims_; + triton::common::DimsList full_dims_; + bool is_decoupled_; + size_t decouple_label_; + bool visited_; + std::vector prev_nodes_; + std::vector next_nodes_; + // A symbolic node to keep track of the decouple label of nodes that + // are outputs of the same step. + std::shared_ptr sibling_node_; +}; + +/// Validate if the data type and the shape of two TensorNode object are +/// consistent. +/// \param lhs One of the TensorNode object to be validated. +/// \param rhs Another TensorNode object to be validated. +/// \param message Extra message included in the front of error message +/// if error status is non-OK. +/// \return The error status. A non-OK status indicates the TensorNode objects +/// are not consistent. +Status +ValidateTensorConsistency( + const TensorNode& lhs, const TensorNode& rhs, const std::string& message) +{ + if (lhs.type_ != rhs.type_) { + return Status( + Status::Code::INVALID_ARG, + message + "inconsistent data type: " + + inference::DataType_Name(lhs.type_) + " is inferred from model " + + lhs.model_name_ + " while " + inference::DataType_Name(rhs.type_) + + " is inferred from model " + rhs.model_name_); + } + + // Shapes must match or either one uses variable size shape, if one uses + // variable size shape, shape consistency will be checked at runtime. + // If dims mismatch, compare agian with full dims in case the tensor is + // used for both non-batching model and batching model. 
In that case, it + // is acceptable if non-batching model shape is [-1, d_0, d_1, ..., d_n] + // while the batching model shape is [d_0, d_1, ..., d_n]. + if (!triton::common::CompareDimsWithWildcard(lhs.dims_, rhs.dims_) && + !triton::common::CompareDimsWithWildcard( + lhs.full_dims_, rhs.full_dims_)) { + return Status( + Status::Code::INVALID_ARG, + message + "inconsistent shape: " + + triton::common::DimsListToString(lhs.full_dims_) + + " is inferred from model " + lhs.model_name_ + " while " + + triton::common::DimsListToString(rhs.full_dims_) + + " is inferred from model " + rhs.model_name_); + } + + return Status::Success; +} + +Status +ValidateTensorMapping( + const std::string& ensemble, const inference::ModelEnsembling::Step& step, + const inference::ModelConfig& model_config, + std::unordered_map* ensemble_tensors) +{ + const bool batching = (model_config.max_batch_size() > 0); + // Check all inputs are mapped and no mapping to invalid inputs + std::set input_names; + for (const auto& model_input : model_config.input()) { + input_names.insert(model_input.name()); + } + for (const auto& input_map : step.input_map()) { + if (input_names.find(input_map.first) == input_names.end()) { + return Status( + Status::Code::INVALID_ARG, + "in ensemble " + ensemble + ", ensemble tensor " + input_map.second + + " is mapping to non-existing input " + input_map.first + + " in model " + step.model_name()); + } + } + for (const auto& model_input : model_config.input()) { + size_t mapped_cnt = 0; + for (const auto& input_map : step.input_map()) { + if (model_input.name() == input_map.first) { + TensorNode model_tensor( + step.model_name(), batching, model_input.data_type(), + model_input.dims()); + auto it = ensemble_tensors->find(input_map.second); + if (it != ensemble_tensors->end()) { + RETURN_IF_ERROR(ValidateTensorConsistency( + it->second, model_tensor, + "in ensemble " + ensemble + ", ensemble tensor " + + input_map.second + ": ")); + } else { + ensemble_tensors->emplace( + std::make_pair(input_map.second, model_tensor)); + } + mapped_cnt++; + } + } + if (mapped_cnt == 0) { + // Allow the input to be excluded from ensemble if it is optional + if (model_input.optional()) { + continue; + } + return Status( + Status::Code::INVALID_ARG, + "in ensemble " + ensemble + ", input " + model_input.name() + + " in model " + model_config.name() + + " is not mapped to any ensemble tensors"); + } else if (mapped_cnt > 1) { + return Status( + Status::Code::INVALID_ARG, + "in ensemble " + ensemble + ", input " + model_input.name() + + " in model " + model_config.name() + + " is mapped to multiple ensemble tensors"); + } + } + + // Check no multiple mappings to same ensemble tensor + // and no mapping from invalid outputs + std::set output_names; + for (const auto& model_output : model_config.output()) { + output_names.insert(model_output.name()); + } + for (const auto& output_map : step.output_map()) { + if (output_names.find(output_map.first) == output_names.end()) { + return Status( + Status::Code::INVALID_ARG, + "in ensemble " + ensemble + ", ensemble tensor " + output_map.second + + " is mapped from non-existing output " + output_map.first + + " in model " + step.model_name()); + } + } + std::shared_ptr sibling_node(new TensorNode(step.model_name())); + for (const auto& output_map : step.output_map()) { + size_t mapped_cnt = 0; + for (const auto& model_output : model_config.output()) { + if (model_output.name() == output_map.first) { + TensorNode model_tensor( + step.model_name(), batching, 
model_output.data_type(), + model_output.dims()); + auto it = ensemble_tensors->find(output_map.second); + if (it != ensemble_tensors->end()) { + RETURN_IF_ERROR(ValidateTensorConsistency( + it->second, model_tensor, + "in ensemble " + ensemble + ", ensemble tensor " + + output_map.second + ": ")); + } else { + it = ensemble_tensors + ->emplace(std::make_pair(output_map.second, model_tensor)) + .first; + } + it->second.sibling_node_ = sibling_node; + mapped_cnt++; + } + } + if (mapped_cnt > 1) { + return Status( + Status::Code::INVALID_ARG, + "in ensemble " + ensemble + ", multiple outputs in model " + + model_config.name() + " are mapped to the same ensemble tensor " + + output_map.second); + } + } + + // link ensemble tensors + bool is_decoupled = model_config.model_transaction_policy().decoupled(); + for (const auto& output_map : step.output_map()) { + auto& node = ensemble_tensors->find(output_map.second)->second; + node.is_decoupled_ = is_decoupled; + for (const auto& input_map : step.input_map()) { + auto& prev_node = ensemble_tensors->find(input_map.second)->second; + node.prev_nodes_.push_back(&prev_node); + prev_node.next_nodes_.push_back(&node); + } + } + return Status::Success; +} + +} // namespace + +Status +ValidateEnsembleConfig( + ModelRepositoryManager* model_repository_manager, + ModelRepositoryManager::DependencyNode* ensemble) +{ + const auto& ensemble_config = ensemble->model_config_; + if (!ensemble_config.has_ensemble_scheduling()) { + return Status::Success; + } + + const auto& ensemble_name = ensemble->model_name_; + const bool batching = (ensemble_config.max_batch_size() > 0); + std::unordered_map ensemble_tensors; + for (const auto& input : ensemble_config.input()) { + const auto& dims = + input.has_reshape() ? input.reshape().shape() : input.dims(); + TensorNode input_node(ensemble_name, batching, input.data_type(), dims); + ensemble_tensors.emplace(std::make_pair(input.name(), input_node)); + } + + TensorNode sink_node(ensemble_name); + for (const auto& output : ensemble_config.output()) { + const auto& dims = + output.has_reshape() ? 
output.reshape().shape() : output.dims(); + TensorNode output_node(ensemble_name, batching, output.data_type(), dims); + auto it = + ensemble_tensors.emplace(std::make_pair(output.name(), output_node)) + .first; + sink_node.prev_nodes_.emplace_back(&(it->second)); + it->second.next_nodes_.emplace_back(&sink_node); + } + + for (const auto& step : ensemble_config.ensemble_scheduling().step()) { + const auto& model_name = step.model_name(); + inference::ModelConfig model_config; + for (auto& node : ensemble->upstreams_) { + if (model_name == node.first->model_name_) { + // Obtain completed config from model instance + std::shared_ptr model; + RETURN_IF_ERROR( + model_repository_manager->GetModel(model_name, -1, &model)); + model_config = model->Config(); + break; + } + } + + // batchable ensemble can include non-batchable models as long as + // the expanded shapes are consistent + if ((model_config.max_batch_size() != 0) && + (model_config.max_batch_size() < ensemble_config.max_batch_size())) { + return Status( + Status::Code::INVALID_ARG, + "ensemble " + ensemble_name + " allows maximum batch size " + + std::to_string(ensemble_config.max_batch_size()) + + ", but it contains model " + model_name + + " which only allows maximum batch size to be " + + std::to_string(model_config.max_batch_size())); + } + + RETURN_IF_ERROR(ValidateTensorMapping( + ensemble_name, step, model_config, &ensemble_tensors)); + } + + // Visit nodes and validate decoupled workflow if any + // check data flow + size_t decouple_label = 0; + std::deque current_iterators; + for (const auto& input : ensemble_config.input()) { + auto it = ensemble_tensors.find(input.name()); + it->second.visited_ = true; + current_iterators.push_back(&(it->second)); + } + while (!current_iterators.empty()) { + auto& current_node = current_iterators.front(); + for (auto& next_node : current_node->next_nodes_) { + if (next_node->visited_) { + continue; + } + bool next_node_ready = true; + for (auto& prev_node : next_node->prev_nodes_) { + if (!prev_node->visited_) { + next_node_ready = false; + break; + } + } + if (next_node_ready) { + size_t prev_decouple_label = next_node->prev_nodes_[0]->decouple_label_; + for (auto& prev_node : next_node->prev_nodes_) { + if (prev_node->decouple_label_ != prev_decouple_label) { + return Status( + Status::Code::INVALID_ARG, + "in ensemble " + ensemble_name + ", step of model '" + + next_node->model_name_ + + "' receives inputs originated from different decoupled " + "models"); + } + } + if (next_node->sibling_node_ != nullptr) { + if (next_node->sibling_node_->visited_) { + next_node->decouple_label_ = + next_node->sibling_node_->decouple_label_; + } else { + next_node->decouple_label_ = next_node->is_decoupled_ + ? ++decouple_label + : prev_decouple_label; + next_node->sibling_node_->decouple_label_ = + next_node->decouple_label_; + next_node->sibling_node_->visited_ = true; + } + } else { + next_node->decouple_label_ = + next_node->is_decoupled_ ? 
++decouple_label : prev_decouple_label; + } + next_node->visited_ = true; + current_iterators.push_back(next_node); + } + } + current_iterators.pop_front(); + } + ensemble->model_config_.mutable_model_transaction_policy()->set_decoupled( + decouple_label != 0); + + return Status::Success; +} + +}} // namespace triton::core + +#endif // TRITON_ENABLE_ENSEMBLE diff --git a/3rdparty/core-r22.12/src/ensemble_utils.h b/3rdparty/core-r22.12/src/ensemble_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..63a9afa85eba94fd2fc8be40bc0a0edac556011b --- /dev/null +++ b/3rdparty/core-r22.12/src/ensemble_utils.h @@ -0,0 +1,50 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#ifdef TRITON_ENABLE_ENSEMBLE + +#include +#include +#include "model_config.pb.h" +#include "model_repository_manager.h" +#include "status.h" +#include "triton/common/model_config.h" + +namespace triton { namespace core { + +/// Validate that the ensemble are specified correctly. Assuming that the +/// inputs and outputs specified in depending model configurations are accurate. +/// \param model_repository_manager The model manager to acquire model config. +/// \param ensemble The ensemble to be validated. +/// \return The error status. +Status ValidateEnsembleConfig( + ModelRepositoryManager* model_repository_manager, + ModelRepositoryManager::DependencyNode* ensemble); + +}} // namespace triton::core + +#endif // TRITON_ENABLE_ENSEMBLE diff --git a/3rdparty/core-r22.12/src/filesystem.cc b/3rdparty/core-r22.12/src/filesystem.cc new file mode 100644 index 0000000000000000000000000000000000000000..e0e2f98e70735635ce6aa80d5d3e6a62f8d2ccb4 --- /dev/null +++ b/3rdparty/core-r22.12/src/filesystem.cc @@ -0,0 +1,2662 @@ +// Copyright 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "filesystem.h" + +#ifdef _WIN32 +// suppress the min and max definitions in Windef.h. +#define NOMINMAX +#include + +// _CRT_INTERNAL_NONSTDC_NAMES 1 before including Microsoft provided C Runtime +// library to expose declarations without "_" prefix to match POSIX style. +#define _CRT_INTERNAL_NONSTDC_NAMES 1 +#include +#include +#else +#include +#include +#endif + +#ifdef TRITON_ENABLE_GCS +#include +#endif // TRITON_ENABLE_GCS + +#ifdef TRITON_ENABLE_S3 +#include +#include +#include +#include +#include +#include +#include +#endif // TRITON_ENABLE_S3 + +#ifdef TRITON_ENABLE_AZURE_STORAGE +#include +#include +#include +#undef LOG_INFO +#undef LOG_WARNING +#endif // TRITON_ENABLE_AZURE_STORAGE + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "constants.h" +#include "status.h" +#include "triton/common/logging.h" + +#define TRITONJSON_STATUSTYPE triton::core::Status +#define TRITONJSON_STATUSRETURN(M) \ + return triton::core::Status(triton::core::Status::Code::INTERNAL, (M)) +#define TRITONJSON_STATUSSUCCESS triton::core::Status::Success +#include "triton/common/triton_json.h" + +#ifdef _WIN32 +// in Windows doesn't define S_ISDIR macro +#if !defined(S_ISDIR) && defined(S_IFMT) && defined(S_IFDIR) +#define S_ISDIR(m) (((m)&S_IFMT) == S_IFDIR) +#endif +#define F_OK 0 +#endif + +namespace triton { namespace core { + +namespace { + +// Check if a local path is a directory. We need to use this in LocalFileSystem +// and LocalizedPath so have this common function. +Status +IsPathDirectory(const std::string& path, bool* is_dir) +{ + *is_dir = false; + + struct stat st; + if (stat(path.c_str(), &st) != 0) { + return Status(Status::Code::INTERNAL, "failed to stat file " + path); + } + + *is_dir = S_ISDIR(st.st_mode); + return Status::Success; +} + +} // namespace + +LocalizedPath::~LocalizedPath() +{ + if (!local_path_.empty()) { + bool is_dir = true; + IsDirectory(local_path_, &is_dir); + LOG_STATUS_ERROR( + DeletePath(is_dir ? 
local_path_ : DirName(local_path_)), + "failed to delete localized path"); + } +} + +namespace { + +class FileSystem { + public: + virtual Status FileExists(const std::string& path, bool* exists) = 0; + virtual Status IsDirectory(const std::string& path, bool* is_dir) = 0; + virtual Status FileModificationTime( + const std::string& path, int64_t* mtime_ns) = 0; + virtual Status GetDirectoryContents( + const std::string& path, std::set* contents) = 0; + virtual Status GetDirectorySubdirs( + const std::string& path, std::set* subdirs) = 0; + virtual Status GetDirectoryFiles( + const std::string& path, std::set* files) = 0; + virtual Status ReadTextFile( + const std::string& path, std::string* contents) = 0; + virtual Status LocalizePath( + const std::string& path, std::shared_ptr* localized) = 0; + virtual Status WriteTextFile( + const std::string& path, const std::string& contents) = 0; + virtual Status WriteBinaryFile( + const std::string& path, const char* contents, + const size_t content_len) = 0; + virtual Status MakeDirectory( + const std::string& dir, const bool recursive) = 0; + virtual Status MakeTemporaryDirectory(std::string* temp_dir) = 0; + virtual Status DeletePath(const std::string& path) = 0; +}; + +class LocalFileSystem : public FileSystem { + public: + Status FileExists(const std::string& path, bool* exists) override; + Status IsDirectory(const std::string& path, bool* is_dir) override; + Status FileModificationTime( + const std::string& path, int64_t* mtime_ns) override; + Status GetDirectoryContents( + const std::string& path, std::set* contents) override; + Status GetDirectorySubdirs( + const std::string& path, std::set* subdirs) override; + Status GetDirectoryFiles( + const std::string& path, std::set* files) override; + Status ReadTextFile(const std::string& path, std::string* contents) override; + Status LocalizePath( + const std::string& path, + std::shared_ptr* localized) override; + Status WriteTextFile( + const std::string& path, const std::string& contents) override; + Status WriteBinaryFile( + const std::string& path, const char* contents, + const size_t content_len) override; + Status MakeDirectory(const std::string& dir, const bool recursive) override; + Status MakeTemporaryDirectory(std::string* temp_dir) override; + Status DeletePath(const std::string& path) override; +}; + +Status +LocalFileSystem::FileExists(const std::string& path, bool* exists) +{ + *exists = (access(path.c_str(), F_OK) == 0); + return Status::Success; +} + +Status +LocalFileSystem::IsDirectory(const std::string& path, bool* is_dir) +{ + return IsPathDirectory(path, is_dir); +} + +Status +LocalFileSystem::FileModificationTime( + const std::string& path, int64_t* mtime_ns) +{ + struct stat st; + if (stat(path.c_str(), &st) != 0) { + return Status(Status::Code::INTERNAL, "failed to stat file " + path); + } + +#ifdef _WIN32 + // In Windows, st_mtime is in time_t + *mtime_ns = std::max(st.st_mtime, st.st_ctime); +#else + *mtime_ns = + std::max(TIMESPEC_TO_NANOS(st.st_mtim), TIMESPEC_TO_NANOS(st.st_ctim)); +#endif + return Status::Success; +} + +Status +LocalFileSystem::GetDirectoryContents( + const std::string& path, std::set* contents) +{ +#ifdef _WIN32 + WIN32_FIND_DATA entry; + // Append "*" to obtain all files under 'path' + HANDLE dir = FindFirstFile(JoinPath({path, "*"}).c_str(), &entry); + if (dir == INVALID_HANDLE_VALUE) { + return Status(Status::Code::INTERNAL, "failed to open directory " + path); + } + if ((strcmp(entry.cFileName, ".") != 0) && + (strcmp(entry.cFileName, "..") != 
0)) { + contents->insert(entry.cFileName); + } + while (FindNextFile(dir, &entry)) { + if ((strcmp(entry.cFileName, ".") != 0) && + (strcmp(entry.cFileName, "..") != 0)) { + contents->insert(entry.cFileName); + } + } + + FindClose(dir); +#else + DIR* dir = opendir(path.c_str()); + if (dir == nullptr) { + return Status(Status::Code::INTERNAL, "failed to open directory " + path); + } + + struct dirent* entry; + while ((entry = readdir(dir)) != nullptr) { + std::string entryname = entry->d_name; + if ((entryname != ".") && (entryname != "..")) { + contents->insert(entryname); + } + } + + closedir(dir); +#endif + return Status::Success; +} + +Status +LocalFileSystem::GetDirectorySubdirs( + const std::string& path, std::set* subdirs) +{ + RETURN_IF_ERROR(GetDirectoryContents(path, subdirs)); + + // Erase non-directory entries... + for (auto iter = subdirs->begin(); iter != subdirs->end();) { + bool is_dir; + RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir)); + if (!is_dir) { + iter = subdirs->erase(iter); + } else { + ++iter; + } + } + + return Status::Success; +} + +Status +LocalFileSystem::GetDirectoryFiles( + const std::string& path, std::set* files) +{ + RETURN_IF_ERROR(GetDirectoryContents(path, files)); + + // Erase directory entries... + for (auto iter = files->begin(); iter != files->end();) { + bool is_dir; + RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir)); + if (is_dir) { + iter = files->erase(iter); + } else { + ++iter; + } + } + + return Status::Success; +} + +Status +LocalFileSystem::ReadTextFile(const std::string& path, std::string* contents) +{ + std::ifstream in(path, std::ios::in | std::ios::binary); + if (!in) { + return Status( + Status::Code::INTERNAL, + "failed to open text file for read " + path + ": " + strerror(errno)); + } + + in.seekg(0, std::ios::end); + contents->resize(in.tellg()); + in.seekg(0, std::ios::beg); + in.read(&(*contents)[0], contents->size()); + in.close(); + + return Status::Success; +} + +Status +LocalFileSystem::LocalizePath( + const std::string& path, std::shared_ptr* localized) +{ + // For local file system we don't actually need to download the + // directory or file. We use it in place. 
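+  //
+  // Illustrative usage (hypothetical caller, not part of this file):
+  //
+  //   std::shared_ptr<LocalizedPath> local;
+  //   RETURN_IF_ERROR(fs.LocalizePath("/opt/models/resnet", &local));
+  //   // local->Path() is the original path for the local filesystem;
+  //   // the cloud filesystems below instead download into a temp directory.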
+ localized->reset(new LocalizedPath(path)); + return Status::Success; +} + +Status +LocalFileSystem::WriteTextFile( + const std::string& path, const std::string& contents) +{ + std::ofstream out(path, std::ios::out | std::ios::binary); + if (!out) { + return Status( + Status::Code::INTERNAL, + "failed to open text file for write " + path + ": " + strerror(errno)); + } + + out.write(&contents[0], contents.size()); + out.close(); + + return Status::Success; +} + +Status +LocalFileSystem::WriteBinaryFile( + const std::string& path, const char* contents, const size_t content_len) +{ + std::ofstream out(path, std::ios::out | std::ios::binary); + if (!out) { + return Status( + Status::Code::INTERNAL, "failed to open binary file for write " + path + + ": " + strerror(errno)); + } + + out.write(contents, content_len); + + return Status::Success; +} + +Status +LocalFileSystem::MakeDirectory(const std::string& dir, const bool recursive) +{ +#ifdef _WIN32 + if (mkdir(dir.c_str()) == -1) +#else + if (mkdir(dir.c_str(), S_IRWXU) == -1) +#endif + { + // Only allow the error due to parent directory does not exist + // if 'recursive' is requested + if ((errno == ENOENT) && (!dir.empty()) && recursive) { + RETURN_IF_ERROR(MakeDirectory(DirName(dir), recursive)); + // Retry the creation +#ifdef _WIN32 + if (mkdir(dir.c_str()) == -1) +#else + if (mkdir(dir.c_str(), S_IRWXU) == -1) +#endif + { + return Status( + Status::Code::INTERNAL, "Failed to create directory '" + dir + + "', errno:" + strerror(errno)); + } + } else { + return Status( + Status::Code::INTERNAL, + "Failed to create directory '" + dir + "', errno:" + strerror(errno)); + } + } + + return Status::Success; +} + +Status +LocalFileSystem::MakeTemporaryDirectory(std::string* temp_dir) +{ +#ifdef _WIN32 + char temp_path[MAX_PATH + 1]; + size_t temp_path_length = GetTempPath(MAX_PATH + 1, temp_path); + if (temp_path_length == 0) { + return Status( + Status::Code::INTERNAL, + "Failed to get local directory for temporary files"); + } + // There is no single operation like 'mkdtemp' in Windows, thus generating + // unique temporary directory is a process of getting temporary file name, + // deleting the file (file creation is side effect fo getting name), creating + // corresponding directory, so mutex is used to avoid possible race condition. + // However, it doesn't prevent other process on creating temporary file and + // thus the race condition may still happen. One possible solution is + // to reserve a temporary directory for the process and generate temporary + // model directories inside it. + static std::mutex mtx; + std::lock_guard lk(mtx); + // Construct a std::string as filled 'temp_path' is not C string, + // and so that we can reuse 'temp_path' to hold the temp file name. 
+ std::string temp_path_str(temp_path, temp_path_length); + if (GetTempFileName(temp_path_str.c_str(), "folder", 0, temp_path) == 0) { + return Status(Status::Code::INTERNAL, "Failed to create local temp folder"); + } + *temp_dir = temp_path; + DeleteFile(temp_dir->c_str()); + if (CreateDirectory(temp_dir->c_str(), NULL) == 0) { + return Status( + Status::Code::INTERNAL, + "Failed to create local temp folder: " + *temp_dir); + } +#else + std::string folder_template = "/tmp/folderXXXXXX"; + char* res = mkdtemp(const_cast(folder_template.c_str())); + if (res == nullptr) { + return Status( + Status::Code::INTERNAL, + "Failed to create local temp folder: " + folder_template + + ", errno:" + strerror(errno)); + } + *temp_dir = res; +#endif + return Status::Success; +} + +Status +LocalFileSystem::DeletePath(const std::string& path) +{ + bool is_dir = false; + RETURN_IF_ERROR(IsDirectory(path, &is_dir)); + if (is_dir) { + std::set contents; + RETURN_IF_ERROR(GetDirectoryContents(path, &contents)); + for (const auto& content : contents) { + RETURN_IF_ERROR(DeletePath(JoinPath({path, content}))); + } + rmdir(path.c_str()); + } else { + remove(path.c_str()); + } + return Status::Success; +} + +#if defined(TRITON_ENABLE_GCS) || defined(TRITON_ENABLE_S3) || \ + defined(TRITON_ENABLE_AZURE_STORAGE) +// Helper function to take care of lack of trailing slashes +std::string +AppendSlash(const std::string& name) +{ + if (name.empty() || (name.back() == '/')) { + return name; + } + + return (name + "/"); +} +#endif // TRITON_ENABLE_GCS || TRITON_ENABLE_S3 || TRITON_ENABLE_AZURE_STORAGE + +#ifdef TRITON_ENABLE_GCS + +namespace gcs = google::cloud::storage; + +struct GCSCredential { + std::string path_; + + GCSCredential(); // from env var + GCSCredential(triton::common::TritonJson::Value& cred_json); +}; + +GCSCredential::GCSCredential() +{ + const char* path = std::getenv("GOOGLE_APPLICATION_CREDENTIALS"); + path_ = (path != nullptr ? 
std::string(path) : ""); +} + +GCSCredential::GCSCredential(triton::common::TritonJson::Value& cred_json) +{ + cred_json.AsString(&path_); +} + +class GCSFileSystem : public FileSystem { + public: + GCSFileSystem(const GCSCredential& gs_cred); + // unify with S3/azure interface + GCSFileSystem(const std::string& path, const GCSCredential& gs_cred) + : GCSFileSystem(gs_cred) + { + } + Status CheckClient(); + // unify with S3 interface + Status CheckClient(const std::string& path) { return CheckClient(); } + + Status FileExists(const std::string& path, bool* exists) override; + Status IsDirectory(const std::string& path, bool* is_dir) override; + Status FileModificationTime( + const std::string& path, int64_t* mtime_ns) override; + Status GetDirectoryContents( + const std::string& path, std::set* contents) override; + Status GetDirectorySubdirs( + const std::string& path, std::set* subdirs) override; + Status GetDirectoryFiles( + const std::string& path, std::set* files) override; + Status ReadTextFile(const std::string& path, std::string* contents) override; + Status LocalizePath( + const std::string& path, + std::shared_ptr* localized) override; + Status WriteTextFile( + const std::string& path, const std::string& contents) override; + Status WriteBinaryFile( + const std::string& path, const char* contents, + const size_t content_len) override; + Status MakeDirectory(const std::string& dir, const bool recursive) override; + Status MakeTemporaryDirectory(std::string* temp_dir) override; + Status DeletePath(const std::string& path) override; + + private: + Status ParsePath( + const std::string& path, std::string* bucket, std::string* object); + Status MetaDataExists( + const std::string path, bool* exists, + google::cloud::StatusOr* metadata); + + google::cloud::StatusOr client_; +}; + +GCSFileSystem::GCSFileSystem(const GCSCredential& gs_cred) +{ + auto creds = gcs::oauth2::CreateServiceAccountCredentialsFromJsonFilePath( + gs_cred.path_); + if (creds) { + client_ = gcs::Client(gcs::ClientOptions(*creds)); + } +} + +Status +GCSFileSystem::CheckClient() +{ + if (!client_) { + return Status( + Status::Code::INTERNAL, + "Unable to create GCS client. Check account credentials."); + } + return Status::Success; +} + +Status +GCSFileSystem::ParsePath( + const std::string& path, std::string* bucket, std::string* object) +{ + // Get the bucket name and the object path. 
Return error if input is malformed + int bucket_start = path.find("gs://") + strlen("gs://"); + int bucket_end = path.find("/", bucket_start); + + // If there isn't a second slash, the address has only the bucket + if (bucket_end > bucket_start) { + *bucket = path.substr(bucket_start, bucket_end - bucket_start); + *object = path.substr(bucket_end + 1); + } else { + *bucket = path.substr(bucket_start); + *object = ""; + } + + if (bucket->empty()) { + return Status( + Status::Code::INTERNAL, "No bucket name found in path: " + path); + } + + return Status::Success; +} + +Status +GCSFileSystem::FileExists(const std::string& path, bool* exists) +{ + *exists = false; + + std::string bucket, object; + RETURN_IF_ERROR(ParsePath(path, &bucket, &object)); + + // Make a request for metadata and check the response + google::cloud::StatusOr object_metadata = + client_->GetObjectMetadata(bucket, object); + + if (object_metadata) { + *exists = true; + return Status::Success; + } + + // GCS doesn't make objects for directories, so it could still be a directory + bool is_dir; + RETURN_IF_ERROR(IsDirectory(path, &is_dir)); + *exists = is_dir; + + return Status::Success; +} + +Status +GCSFileSystem::IsDirectory(const std::string& path, bool* is_dir) +{ + *is_dir = false; + std::string bucket, object_path; + RETURN_IF_ERROR(ParsePath(path, &bucket, &object_path)); + + // Check if the bucket exists + google::cloud::StatusOr bucket_metadata = + client_->GetBucketMetadata(bucket); + + if (!bucket_metadata) { + return Status( + Status::Code::INTERNAL, "Could not get MetaData for bucket with name " + + bucket + " : " + + bucket_metadata.status().message()); + } + + // Root case - bucket exists and object path is empty + if (object_path.empty()) { + *is_dir = true; + return Status::Success; + } + + // Check whether it has children. 
If at least one child, it is a directory + for (auto&& object_metadata : + client_->ListObjects(bucket, gcs::Prefix(AppendSlash(object_path)))) { + if (object_metadata) { + *is_dir = true; + break; + } + } + return Status::Success; +} + +Status +GCSFileSystem::FileModificationTime(const std::string& path, int64_t* mtime_ns) +{ + // We don't need to worry about the case when this is a directory + bool is_dir; + RETURN_IF_ERROR(IsDirectory(path, &is_dir)); + if (is_dir) { + *mtime_ns = 0; + return Status::Success; + } + + std::string bucket, object; + RETURN_IF_ERROR(ParsePath(path, &bucket, &object)); + + // Otherwise check the object metadata for update time + google::cloud::StatusOr object_metadata = + client_->GetObjectMetadata(bucket, object); + + if (!object_metadata) { + return Status( + Status::Code::INTERNAL, "Failed to get metadata for " + object + " : " + + object_metadata.status().message()); + } + + // Get duration from time point with respect to object clock + auto update_time = std::chrono::time_point_cast( + object_metadata->updated()) + .time_since_epoch() + .count(); + + *mtime_ns = update_time; + return Status::Success; +} + +Status +GCSFileSystem::GetDirectoryContents( + const std::string& path, std::set* contents) +{ + std::string bucket, dir_path; + RETURN_IF_ERROR(ParsePath(path, &bucket, &dir_path)); + // Append a slash to make it easier to list contents + std::string full_dir = AppendSlash(dir_path); + + // Get objects with prefix equal to full directory path + for (auto&& object_metadata : + client_->ListObjects(bucket, gcs::Prefix(full_dir))) { + if (!object_metadata) { + return Status( + Status::Code::INTERNAL, "Could not list contents of directory at " + + path + " : " + + object_metadata.status().message()); + } + + // In the case of empty directories, the directory itself will appear here + if (object_metadata->name() == full_dir) { + continue; + } + + // We have to make sure that subdirectory contents do not appear here + std::string name = object_metadata->name(); + int item_start = name.find(full_dir) + full_dir.size(); + // GCS response prepends parent directory name + int item_end = name.find("/", item_start); + + // Let set take care of subdirectory contents + std::string item = name.substr(item_start, item_end - item_start); + contents->insert(item); + } + return Status::Success; +} + +Status +GCSFileSystem::GetDirectorySubdirs( + const std::string& path, std::set* subdirs) +{ + RETURN_IF_ERROR(GetDirectoryContents(path, subdirs)); + + // Erase non-directory entries... + for (auto iter = subdirs->begin(); iter != subdirs->end();) { + bool is_dir; + RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir)); + if (!is_dir) { + iter = subdirs->erase(iter); + } else { + ++iter; + } + } + + return Status::Success; +} + +Status +GCSFileSystem::GetDirectoryFiles( + const std::string& path, std::set* files) +{ + RETURN_IF_ERROR(GetDirectoryContents(path, files)); + + // Erase directory entries... 
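+  // (mirror of GetDirectorySubdirs above, but keeping files and dropping
+  // directories)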
+ for (auto iter = files->begin(); iter != files->end();) { + bool is_dir; + RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir)); + if (is_dir) { + iter = files->erase(iter); + } else { + ++iter; + } + } + + return Status::Success; +} + +Status +GCSFileSystem::ReadTextFile(const std::string& path, std::string* contents) +{ + bool exists; + RETURN_IF_ERROR(FileExists(path, &exists)); + + if (!exists) { + return Status(Status::Code::INTERNAL, "File does not exist at " + path); + } + + std::string bucket, object; + ParsePath(path, &bucket, &object); + + gcs::ObjectReadStream stream = client_->ReadObject(bucket, object); + + if (!stream) { + return Status( + Status::Code::INTERNAL, "Failed to open object read stream for " + + path + " : " + stream.status().message()); + } + + std::string data = ""; + char c; + while (stream.get(c)) { + data += c; + } + + *contents = data; + + return Status::Success; +} + +Status +GCSFileSystem::LocalizePath( + const std::string& path, std::shared_ptr* localized) +{ + bool exists; + RETURN_IF_ERROR(FileExists(path, &exists)); + if (!exists) { + return Status( + Status::Code::INTERNAL, "directory or file does not exist at " + path); + } + + bool is_dir; + RETURN_IF_ERROR(IsDirectory(path, &is_dir)); + if (!is_dir) { + return Status( + Status::Code::UNSUPPORTED, + "GCS file localization not yet implemented " + path); + } + + std::string tmp_folder; + RETURN_IF_ERROR( + triton::core::MakeTemporaryDirectory(FileSystemType::LOCAL, &tmp_folder)); + + localized->reset(new LocalizedPath(path, tmp_folder)); + + std::set contents, filenames; + RETURN_IF_ERROR(GetDirectoryContents(path, &filenames)); + for (auto itr = filenames.begin(); itr != filenames.end(); ++itr) { + contents.insert(JoinPath({path, *itr})); + } + + while (contents.size() != 0) { + std::set tmp_contents = contents; + contents.clear(); + for (auto iter = tmp_contents.begin(); iter != tmp_contents.end(); ++iter) { + bool is_subdir; + std::string gcs_fpath = *iter; + std::string gcs_removed_path = gcs_fpath.substr(path.size()); + std::string local_fpath = + JoinPath({(*localized)->Path(), gcs_removed_path}); + RETURN_IF_ERROR(IsDirectory(gcs_fpath, &is_subdir)); + if (is_subdir) { + // Create local mirror of sub-directories +#ifdef _WIN32 + int status = mkdir(const_cast(local_fpath.c_str())); +#else + int status = mkdir( + const_cast(local_fpath.c_str()), + S_IRUSR | S_IWUSR | S_IXUSR); +#endif + if (status == -1) { + return Status( + Status::Code::INTERNAL, + "Failed to create local folder: " + local_fpath + + ", errno:" + strerror(errno)); + } + + // Add sub-directories and deeper files to contents + std::set subdir_contents; + RETURN_IF_ERROR(GetDirectoryContents(gcs_fpath, &subdir_contents)); + for (auto itr = subdir_contents.begin(); itr != subdir_contents.end(); + ++itr) { + contents.insert(JoinPath({gcs_fpath, *itr})); + } + } else { + // Create local copy of file + std::string file_bucket, file_object; + RETURN_IF_ERROR(ParsePath(gcs_fpath, &file_bucket, &file_object)); + + // Send a request to read the object + gcs::ObjectReadStream filestream = + client_->ReadObject(file_bucket, file_object); + if (!filestream) { + return Status( + Status::Code::INTERNAL, "Failed to get object at " + *iter + + " : " + + filestream.status().message()); + } + + std::string gcs_removed_path = (*iter).substr(path.size()); + std::string local_file_path = + JoinPath({(*localized)->Path(), gcs_removed_path}); + std::ofstream output_file(local_file_path.c_str(), std::ios::binary); + output_file << 
filestream.rdbuf(); + output_file.close(); + } + } + } + + return Status::Success; +} + +Status +GCSFileSystem::WriteTextFile( + const std::string& path, const std::string& contents) +{ + return Status( + Status::Code::UNSUPPORTED, + "Write text file operation not yet implemented " + path); +} + +Status +GCSFileSystem::WriteBinaryFile( + const std::string& path, const char* contents, const size_t content_len) +{ + return Status( + Status::Code::UNSUPPORTED, + "Write text file operation not yet implemented " + path); +} + +Status +GCSFileSystem::MakeDirectory(const std::string& dir, const bool recursive) +{ + return Status( + Status::Code::UNSUPPORTED, + "Make temporary directory operation not yet implemented"); +} + +Status +GCSFileSystem::MakeTemporaryDirectory(std::string* temp_dir) +{ + return Status( + Status::Code::UNSUPPORTED, + "Make temporary directory operation not yet implemented"); +} + +Status +GCSFileSystem::DeletePath(const std::string& path) +{ + return Status( + Status::Code::UNSUPPORTED, "Delete path operation not yet implemented"); +} + +#endif // TRITON_ENABLE_GCS + + +#ifdef TRITON_ENABLE_AZURE_STORAGE + +namespace as = azure::storage_lite; +const std::string AS_URL_PATTERN = "as://([^/]+)/([^/?]+)(?:/([^?]*))?(\\?.*)?"; + +struct ASCredential { + std::string account_str_; + std::string account_key_; + + ASCredential(); // from env var + ASCredential(triton::common::TritonJson::Value& cred_json); +}; + +ASCredential::ASCredential() +{ + const auto to_str = [](const char* s) -> std::string { + return (s != nullptr ? std::string(s) : ""); + }; + const char* account_str = std::getenv("AZURE_STORAGE_ACCOUNT"); + const char* account_key = std::getenv("AZURE_STORAGE_KEY"); + account_str_ = to_str(account_str); + account_key_ = to_str(account_key); +} + +ASCredential::ASCredential(triton::common::TritonJson::Value& cred_json) +{ + triton::common::TritonJson::Value account_str_json, account_key_json; + if (cred_json.Find("account_str", &account_str_json)) + account_str_json.AsString(&account_str_); + if (cred_json.Find("account_key", &account_key_json)) + account_key_json.AsString(&account_key_); +} + +class ASFileSystem : public FileSystem { + public: + ASFileSystem(const std::string& path, const ASCredential& as_cred); + Status CheckClient(); + // unify with S3 interface + Status CheckClient(const std::string& path) { return CheckClient(); } + + Status FileExists(const std::string& path, bool* exists) override; + Status IsDirectory(const std::string& path, bool* is_dir) override; + Status FileModificationTime( + const std::string& path, int64_t* mtime_ns) override; + Status GetDirectoryContents( + const std::string& path, std::set* contents) override; + Status GetDirectorySubdirs( + const std::string& path, std::set* subdirs) override; + Status GetDirectoryFiles( + const std::string& path, std::set* files) override; + Status ReadTextFile(const std::string& path, std::string* contents) override; + Status LocalizePath( + const std::string& path, + std::shared_ptr* localized) override; + Status WriteTextFile( + const std::string& path, const std::string& contents) override; + Status WriteBinaryFile( + const std::string& path, const char* contents, + const size_t content_len) override; + Status MakeDirectory(const std::string& dir, const bool recursive) override; + Status MakeTemporaryDirectory(std::string* temp_dir) override; + Status DeletePath(const std::string& path) override; + + private: + Status ParsePath( + const std::string& path, std::string* bucket, std::string* 
object); + std::shared_ptr client_; + + Status ListDirectory( + const std::string& path, const std::string& dir_path, + std::function< + Status(const as::list_blobs_segmented_item&, const std::string&)> + func); + + Status DownloadFolder( + const std::string& container, const std::string& path, + const std::string& dest); + re2::RE2 as_regex_; +}; + +Status +ASFileSystem::ParsePath( + const std::string& path, std::string* container, std::string* object) +{ + std::string host_name, query; + if (!RE2::FullMatch(path, as_regex_, &host_name, container, object, &query)) { + return Status( + Status::Code::INTERNAL, "Invalid azure storage path: " + path); + } + return Status::Success; +} + +ASFileSystem::ASFileSystem(const std::string& path, const ASCredential& as_cred) + : as_regex_(AS_URL_PATTERN) +{ + std::shared_ptr account = nullptr; + std::string host_name, container, blob_path, query; + if (RE2::FullMatch( + path, as_regex_, &host_name, &container, &blob_path, &query)) { + size_t pos = host_name.rfind(".blob.core.windows.net"); + std::string account_name; + if (as_cred.account_str_.empty()) { + if (pos != std::string::npos) { + account_name = host_name.substr(0, pos); + } else { + account_name = host_name; + } + } else { + account_name = as_cred.account_str_; + } + + std::shared_ptr cred; + if (!as_cred.account_key_.empty()) { + // Shared Key + cred = std::make_shared( + account_name, as_cred.account_key_); + } else { + cred = std::make_shared(); + } + account = std::make_shared( + account_name, cred, /* use_https */ true); + client_ = + std::make_shared(account, /*max_concurrency*/ 16); + } +} + +Status +ASFileSystem::CheckClient() +{ + if (client_ == nullptr) { + return Status( + Status::Code::INTERNAL, + "Unable to create Azure filesystem client. 
Check account credentials."); + } + return Status::Success; +} + + +Status +ASFileSystem::FileModificationTime(const std::string& path, int64_t* mtime_ns) +{ + as::blob_client_wrapper bc(client_); + std::string container, object_path; + RETURN_IF_ERROR(ParsePath(path, &container, &object_path)); + + auto blobProperty = bc.get_blob_property(container, object_path); + if (errno != 0) { + return Status( + Status::Code::INTERNAL, "Unable to get blob property for file at " + + path + ", errno:" + strerror(errno)); + } + + auto time = + std::chrono::system_clock::from_time_t(blobProperty.last_modified); + auto update_time = + std::chrono::time_point_cast(time) + .time_since_epoch() + .count(); + + *mtime_ns = update_time; + return Status::Success; +}; + +Status +ASFileSystem::ListDirectory( + const std::string& container, const std::string& dir_path, + std::function< + Status(const as::list_blobs_segmented_item&, const std::string&)> + func) +{ + as::blob_client_wrapper bc(client_); + + // Append a slash to make it easier to list contents + std::string full_dir = AppendSlash(dir_path); + auto blobs = bc.list_blobs_segmented(container, "/", "", full_dir); + if (errno != 0) { + return Status( + Status::Code::INTERNAL, "Failed to get contents of directory " + + dir_path + ", errno:" + strerror(errno)); + } + + for (auto&& item : blobs.blobs) { + std::string name = item.name; + int item_start = name.find(full_dir) + full_dir.size(); + int item_end = name.find("/", item_start); + // Let set take care of subdirectory contents + std::string subfile = name.substr(item_start, item_end - item_start); + auto status = func(item, subfile); + if (!status.IsOk()) { + return status; + } + } + return Status::Success; +} + +Status +ASFileSystem::GetDirectoryContents( + const std::string& path, std::set* contents) +{ + auto func = [&](const as::list_blobs_segmented_item& item, + const std::string& dir) { + contents->insert(dir); + return Status::Success; + }; + std::string container, dir_path; + RETURN_IF_ERROR(ParsePath(path, &container, &dir_path)); + return ListDirectory(container, dir_path, func); +} + +Status +ASFileSystem::GetDirectorySubdirs( + const std::string& path, std::set* subdirs) +{ + auto func = [&](const as::list_blobs_segmented_item& item, + const std::string& dir) { + if (item.is_directory) { + subdirs->insert(dir); + } + return Status::Success; + }; + std::string container, dir_path; + RETURN_IF_ERROR(ParsePath(path, &container, &dir_path)); + return ListDirectory(container, dir_path, func); +} + +Status +ASFileSystem::GetDirectoryFiles( + const std::string& path, std::set* files) +{ + auto func = [&](const as::list_blobs_segmented_item& item, + const std::string& file) { + if (!item.is_directory) { + files->insert(file); + } + return Status::Success; + }; + std::string container, dir_path; + RETURN_IF_ERROR(ParsePath(path, &container, &dir_path)); + return ListDirectory(container, dir_path, func); +} + +Status +ASFileSystem::IsDirectory(const std::string& path, bool* is_dir) +{ + *is_dir = false; + std::string container, object_path; + RETURN_IF_ERROR(ParsePath(path, &container, &object_path)); + + as::blob_client_wrapper bc(client_); + auto blobs = bc.list_blobs_segmented(container, "/", "", object_path, 1); + if (errno != 0) { + return Status( + Status::Code::INTERNAL, "Failed to check if directory at " + path + + ", errno:" + strerror(errno)); + } + *is_dir = blobs.blobs.size() > 0; + + return Status::Success; +}; + +Status +ASFileSystem::ReadTextFile(const std::string& path, std::string* 
contents) +{ + as::blob_client_wrapper bc(client_); + std::string container, object_path; + RETURN_IF_ERROR(ParsePath(path, &container, &object_path)); + using namespace azure::storage_lite; + std::ostringstream out_stream; + bc.download_blob_to_stream(container, object_path, 0, 0, out_stream); + if (errno != 0) { + return Status( + Status::Code::INTERNAL, "Failed to fetch file stream at " + path + + ", errno:" + strerror(errno)); + } + *contents = out_stream.str(); + + return Status::Success; +} + +Status +ASFileSystem::FileExists(const std::string& path, bool* exists) +{ + *exists = false; + + std::string container, object; + RETURN_IF_ERROR(ParsePath(path, &container, &object)); + as::blob_client_wrapper bc(client_); + auto blobs = bc.list_blobs_segmented(container, "/", "", object, 1); + if (errno != 0) { + return Status( + Status::Code::INTERNAL, "Failed to check if file exists at " + path + + ", errno:" + strerror(errno)); + } + if (blobs.blobs.size() > 0) { + *exists = true; + } + return Status::Success; +} + +Status +ASFileSystem::DownloadFolder( + const std::string& container, const std::string& path, + const std::string& dest) +{ + as::blob_client_wrapper bc(client_); + auto func = [&](const as::list_blobs_segmented_item& item, + const std::string& dir) { + auto local_path = JoinPath({dest, dir}); + auto blob_path = JoinPath({path, dir}); + if (item.is_directory) { + int status = mkdir( + const_cast(local_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); + if (status == -1) { + return Status( + Status::Code::INTERNAL, + "Failed to create local folder: " + local_path + + ", errno:" + strerror(errno)); + } + auto ret = DownloadFolder(container, blob_path, local_path); + if (!ret.IsOk()) { + return ret; + } + } else { + time_t last_modified; + bc.download_blob_to_file(container, blob_path, local_path, last_modified); + if (errno != 0) { + return Status( + Status::Code::INTERNAL, "Failed to download file at " + blob_path + + ", errno:" + strerror(errno)); + } + } + return Status::Success; + }; + return ListDirectory(container, path, func); +} + +Status +ASFileSystem::LocalizePath( + const std::string& path, std::shared_ptr* localized) +{ + bool exists; + RETURN_IF_ERROR(FileExists(path, &exists)); + if (!exists) { + return Status( + Status::Code::INTERNAL, "directory or file does not exist at " + path); + } + + bool is_dir; + RETURN_IF_ERROR(IsDirectory(path, &is_dir)); + if (!is_dir) { + return Status( + Status::Code::UNSUPPORTED, + "AS file localization not yet implemented " + path); + } + + std::string folder_template = "/tmp/folderXXXXXX"; + char* tmp_folder = mkdtemp(const_cast(folder_template.c_str())); + if (tmp_folder == nullptr) { + return Status( + Status::Code::INTERNAL, + "Failed to create local temp folder: " + folder_template + + ", errno:" + strerror(errno)); + } + localized->reset(new LocalizedPath(path, tmp_folder)); + + std::string dest(folder_template); + + as::blob_client_wrapper bc(client_); + + std::string container, object; + RETURN_IF_ERROR(ParsePath(path, &container, &object)); + return DownloadFolder(container, object, dest); +} + +Status +ASFileSystem::WriteTextFile( + const std::string& path, const std::string& contents) +{ + std::stringstream ss(contents); + std::istream is(ss.rdbuf()); + std::string container, object; + RETURN_IF_ERROR(ParsePath(path, &container, &object)); + std::vector> metadata; + auto ret = + client_->upload_block_blob_from_stream(container, object, is, metadata) + .get(); + if (!ret.success()) { + return Status( + Status::Code::INTERNAL, 
+ "Failed to upload blob, Error: " + ret.error().code + ", " + + ret.error().code_name); + } + return Status::Success; +} + +Status +ASFileSystem::WriteBinaryFile( + const std::string& path, const char* contents, const size_t content_len) +{ + return Status( + Status::Code::UNSUPPORTED, + "Write text file operation not yet implemented " + path); +} + +Status +ASFileSystem::MakeDirectory(const std::string& dir, const bool recursive) +{ + return Status( + Status::Code::UNSUPPORTED, + "Make directory operation not yet implemented"); +} + +Status +ASFileSystem::MakeTemporaryDirectory(std::string* temp_dir) +{ + return Status( + Status::Code::UNSUPPORTED, + "Make temporary directory operation not yet implemented"); +} + +Status +ASFileSystem::DeletePath(const std::string& path) +{ + return Status( + Status::Code::UNSUPPORTED, "Delete path operation not yet implemented"); +} + +#endif // TRITON_ENABLE_AZURE_STORAGE + + +#ifdef TRITON_ENABLE_S3 + +namespace s3 = Aws::S3; + +struct S3Credential { + std::string secret_key_; + std::string key_id_; + std::string region_; + std::string session_token_; + std::string profile_name_; + + S3Credential(); // from env var + S3Credential(triton::common::TritonJson::Value& cred_json); +}; + +S3Credential::S3Credential() +{ + const auto to_str = [](const char* s) -> std::string { + return (s != nullptr ? std::string(s) : ""); + }; + const char* secret_key = std::getenv("AWS_SECRET_ACCESS_KEY"); + const char* key_id = std::getenv("AWS_ACCESS_KEY_ID"); + const char* region = std::getenv("AWS_DEFAULT_REGION"); + const char* session_token = std::getenv("AWS_SESSION_TOKEN"); + const char* profile = std::getenv("AWS_PROFILE"); + secret_key_ = to_str(secret_key); + key_id_ = to_str(key_id); + region_ = to_str(region); + session_token_ = to_str(session_token); + profile_name_ = to_str(profile); +} + +S3Credential::S3Credential(triton::common::TritonJson::Value& cred_json) +{ + triton::common::TritonJson::Value secret_key_json, key_id_json, region_json, + session_token_json, profile_json; + if (cred_json.Find("secret_key", &secret_key_json)) + secret_key_json.AsString(&secret_key_); + if (cred_json.Find("key_id", &key_id_json)) + key_id_json.AsString(&key_id_); + if (cred_json.Find("region", ®ion_json)) + region_json.AsString(®ion_); + if (cred_json.Find("session_token", &session_token_json)) + session_token_json.AsString(&session_token_); + if (cred_json.Find("profile", &profile_json)) + profile_json.AsString(&profile_name_); +} + +class S3FileSystem : public FileSystem { + public: + S3FileSystem(const std::string& s3_path, const S3Credential& s3_cred); + Status CheckClient(const std::string& s3_path); + + Status FileExists(const std::string& path, bool* exists) override; + Status IsDirectory(const std::string& path, bool* is_dir) override; + Status FileModificationTime( + const std::string& path, int64_t* mtime_ns) override; + Status GetDirectoryContents( + const std::string& path, std::set* contents) override; + Status GetDirectorySubdirs( + const std::string& path, std::set* subdirs) override; + Status GetDirectoryFiles( + const std::string& path, std::set* files) override; + Status ReadTextFile(const std::string& path, std::string* contents) override; + Status LocalizePath( + const std::string& path, + std::shared_ptr* localized) override; + Status WriteTextFile( + const std::string& path, const std::string& contents) override; + Status WriteBinaryFile( + const std::string& path, const char* contents, + const size_t content_len) override; + Status 
MakeDirectory(const std::string& dir, const bool recursive) override; + Status MakeTemporaryDirectory(std::string* temp_dir) override; + Status DeletePath(const std::string& path) override; + + private: + Status ParsePath( + const std::string& path, std::string* bucket, std::string* object); + Status CleanPath(const std::string& s3_path, std::string* clean_path); + std::unique_ptr client_; // init after Aws::InitAPI is called + re2::RE2 s3_regex_; +}; + +Status +S3FileSystem::ParsePath( + const std::string& path, std::string* bucket, std::string* object) +{ + // Cleanup extra slashes + std::string clean_path; + RETURN_IF_ERROR(CleanPath(path, &clean_path)); + + // Get the bucket name and the object path. Return error if path is malformed + std::string protocol, host_name, host_port; + if (!RE2::FullMatch( + clean_path, s3_regex_, &protocol, &host_name, &host_port, bucket, + object)) { + int bucket_start = clean_path.find("s3://") + strlen("s3://"); + int bucket_end = clean_path.find("/", bucket_start); + + // If there isn't a slash, the address has only the bucket + if (bucket_end > bucket_start) { + *bucket = clean_path.substr(bucket_start, bucket_end - bucket_start); + *object = clean_path.substr(bucket_end + 1); + } else { + *bucket = clean_path.substr(bucket_start); + *object = ""; + } + } else { + // Erase leading '/' that is left behind in object name + if ((*object)[0] == '/') { + object->erase(0, 1); + } + } + + if (bucket->empty()) { + return Status( + Status::Code::INTERNAL, "No bucket name found in path: " + path); + } + + return Status::Success; +} + +Status +S3FileSystem::CleanPath(const std::string& s3_path, std::string* clean_path) +{ + // Must handle paths with s3 prefix + size_t start = s3_path.find("s3://"); + std::string path = ""; + if (start != std::string::npos) { + path = s3_path.substr(start + strlen("s3://")); + *clean_path = "s3://"; + } else { + path = s3_path; + *clean_path = ""; + } + + // Must handle paths with https:// or http:// prefix + size_t https_start = path.find("https://"); + if (https_start != std::string::npos) { + path = path.substr(https_start + strlen("https://")); + *clean_path += "https://"; + } else { + size_t http_start = path.find("http://"); + if (http_start != std::string::npos) { + path = path.substr(http_start + strlen("http://")); + *clean_path += "http://"; + } + } + + // Remove trailing slashes + size_t rtrim_length = path.find_last_not_of('/'); + if (rtrim_length == std::string::npos) { + return Status( + Status::Code::INVALID_ARG, "Invalid bucket name: '" + path + "'"); + } + + // Remove leading slashes + size_t ltrim_length = path.find_first_not_of('/'); + if (ltrim_length == std::string::npos) { + return Status( + Status::Code::INVALID_ARG, "Invalid bucket name: '" + path + "'"); + } + + // Remove extra internal slashes + std::string true_path = path.substr(ltrim_length, rtrim_length + 1); + std::vector slash_locations; + bool previous_slash = false; + for (size_t i = 0; i < true_path.size(); i++) { + if (true_path[i] == '/') { + if (!previous_slash) { + *clean_path += true_path[i]; + } + previous_slash = true; + } else { + *clean_path += true_path[i]; + previous_slash = false; + } + } + + return Status::Success; +} + +S3FileSystem::S3FileSystem( + const std::string& s3_path, const S3Credential& s3_cred) + : s3_regex_( + "s3://(http://|https://|)([0-9a-zA-Z\\-.]+):([0-9]+)/" + "([0-9a-z.\\-]+)(((/[0-9a-zA-Z.\\-_]+)*)?)") +{ + // init aws api if not already + Aws::SDKOptions options; + static std::once_flag onceFlag; + 
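+  // Aws::InitAPI is intended to run once per process; the static once_flag
+  // above ensures the SDK is initialized at most once even if several
+  // S3FileSystem instances are constructed concurrently.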
std::call_once(onceFlag, [&options] { Aws::InitAPI(options); }); + + Aws::Client::ClientConfiguration config; + Aws::Auth::AWSCredentials credentials; + + // check vars for S3 credentials -> aws profile -> default + if (!s3_cred.secret_key_.empty() && !s3_cred.key_id_.empty()) { + credentials.SetAWSAccessKeyId(s3_cred.key_id_.c_str()); + credentials.SetAWSSecretKey(s3_cred.secret_key_.c_str()); + if (!s3_cred.session_token_.empty()) { + credentials.SetSessionToken(s3_cred.session_token_.c_str()); + } + config = Aws::Client::ClientConfiguration(); + if (!s3_cred.region_.empty()) { + config.region = s3_cred.region_.c_str(); + } + } else if (!s3_cred.profile_name_.empty()) { + config = Aws::Client::ClientConfiguration(s3_cred.profile_name_.c_str()); + } else { + config = Aws::Client::ClientConfiguration("default"); + } + + // Cleanup extra slashes + std::string clean_path; + LOG_STATUS_ERROR(CleanPath(s3_path, &clean_path), "failed to parse S3 path"); + + std::string protocol, host_name, host_port, bucket, object; + if (RE2::FullMatch( + clean_path, s3_regex_, &protocol, &host_name, &host_port, &bucket, + &object)) { + config.endpointOverride = Aws::String(host_name + ":" + host_port); + if (protocol == "https://") { + config.scheme = Aws::Http::Scheme::HTTPS; + } else { + config.scheme = Aws::Http::Scheme::HTTP; + } + } + + if (!s3_cred.secret_key_.empty() && !s3_cred.key_id_.empty()) { + client_ = std::make_unique( + credentials, config, + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, + /*useVirtualAdressing*/ false); + } else { + client_ = std::make_unique( + config, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, + /*useVirtualAdressing*/ false); + } +} + +Status +S3FileSystem::CheckClient(const std::string& s3_path) +{ + std::string bucket, object_path; + RETURN_IF_ERROR(ParsePath(s3_path, &bucket, &object_path)); + // check if can connect to the bucket + s3::Model::HeadBucketRequest head_request; + head_request.WithBucket(bucket.c_str()); + if (!client_->HeadBucket(head_request).IsSuccess()) { + return Status( + Status::Code::INTERNAL, + "Unable to create S3 filesystem client. 
Check account credentials."); + } + return Status::Success; +} + +Status +S3FileSystem::FileExists(const std::string& path, bool* exists) +{ + *exists = false; + + // S3 doesn't make objects for directories, so it could still be a directory + bool is_dir; + RETURN_IF_ERROR(IsDirectory(path, &is_dir)); + if (is_dir) { + *exists = is_dir; + return Status::Success; + } + + std::string bucket, object; + RETURN_IF_ERROR(ParsePath(path, &bucket, &object)); + + // Construct request for object metadata + s3::Model::HeadObjectRequest head_request; + head_request.SetBucket(bucket.c_str()); + head_request.SetKey(object.c_str()); + + auto head_object_outcome = client_->HeadObject(head_request); + if (!head_object_outcome.IsSuccess()) { + if (head_object_outcome.GetError().GetErrorType() != + s3::S3Errors::RESOURCE_NOT_FOUND) { + return Status( + Status::Code::INTERNAL, + "Could not get MetaData for object at " + path + + " due to exception: " + + head_object_outcome.GetError().GetExceptionName() + + ", error message: " + + head_object_outcome.GetError().GetMessage()); + } + } else { + *exists = true; + } + + return Status::Success; +} + +Status +S3FileSystem::IsDirectory(const std::string& path, bool* is_dir) +{ + *is_dir = false; + std::string bucket, object_path; + RETURN_IF_ERROR(ParsePath(path, &bucket, &object_path)); + + // Check if the bucket exists + s3::Model::HeadBucketRequest head_request; + head_request.WithBucket(bucket.c_str()); + + auto head_bucket_outcome = client_->HeadBucket(head_request); + if (!head_bucket_outcome.IsSuccess()) { + return Status( + Status::Code::INTERNAL, + "Could not get MetaData for bucket with name " + bucket + + " due to exception: " + + head_bucket_outcome.GetError().GetExceptionName() + + ", error message: " + head_bucket_outcome.GetError().GetMessage()); + } + + // Root case - bucket exists and object path is empty + if (object_path.empty()) { + *is_dir = true; + return Status::Success; + } + + // List the objects in the bucket + s3::Model::ListObjectsRequest list_objects_request; + list_objects_request.SetBucket(bucket.c_str()); + list_objects_request.SetPrefix(AppendSlash(object_path).c_str()); + auto list_objects_outcome = client_->ListObjects(list_objects_request); + + if (list_objects_outcome.IsSuccess()) { + *is_dir = !list_objects_outcome.GetResult().GetContents().empty(); + } else { + return Status( + Status::Code::INTERNAL, + "Failed to list objects with prefix " + path + " due to exception: " + + list_objects_outcome.GetError().GetExceptionName() + + ", error message: " + list_objects_outcome.GetError().GetMessage()); + } + return Status::Success; +} + +Status +S3FileSystem::FileModificationTime(const std::string& path, int64_t* mtime_ns) +{ + // We don't need to worry about the case when this is a directory + bool is_dir; + RETURN_IF_ERROR(IsDirectory(path, &is_dir)); + if (is_dir) { + *mtime_ns = 0; + return Status::Success; + } + + std::string bucket, object; + RETURN_IF_ERROR(ParsePath(path, &bucket, &object)); + + // Send a request for the objects metadata + s3::Model::HeadObjectRequest head_request; + head_request.SetBucket(bucket.c_str()); + head_request.SetKey(object.c_str()); + + // If request succeeds, copy over the modification time + auto head_object_outcome = client_->HeadObject(head_request); + if (head_object_outcome.IsSuccess()) { + *mtime_ns = head_object_outcome.GetResult().GetLastModified().Millis() * + NANOS_PER_MILLIS; + } else { + return Status( + Status::Code::INTERNAL, + "Failed to get modification time for object at " + path 
+ + " due to exception: " + + head_object_outcome.GetError().GetExceptionName() + + ", error message: " + head_object_outcome.GetError().GetMessage()); + } + return Status::Success; +} + +Status +S3FileSystem::GetDirectoryContents( + const std::string& path, std::set* contents) +{ + // Parse bucket and dir_path + std::string bucket, dir_path, full_dir; + RETURN_IF_ERROR(ParsePath(path, &bucket, &dir_path)); + std::string true_path = "s3://" + bucket + '/' + dir_path; + + // Capture the full path to facilitate content listing + full_dir = AppendSlash(dir_path); + + // Issue request for objects with prefix + s3::Model::ListObjectsRequest objects_request; + objects_request.SetBucket(bucket.c_str()); + objects_request.SetPrefix(full_dir.c_str()); + auto list_objects_outcome = client_->ListObjects(objects_request); + + if (list_objects_outcome.IsSuccess()) { + Aws::Vector object_list = + list_objects_outcome.GetResult().GetContents(); + for (auto const& s3_object : object_list) { + // In the case of empty directories, the directory itself will appear here + if (s3_object.GetKey().c_str() == full_dir) { + continue; + } + + // We have to make sure that subdirectory contents do not appear here + std::string name(s3_object.GetKey().c_str()); + int item_start = name.find(full_dir) + full_dir.size(); + // S3 response prepends parent directory name + int item_end = name.find("/", item_start); + + // Let set take care of subdirectory contents + std::string item = name.substr(item_start, item_end - item_start); + contents->insert(item); + } + } else { + return Status( + Status::Code::INTERNAL, + "Could not list contents of directory at " + true_path + + " due to exception: " + + list_objects_outcome.GetError().GetExceptionName() + + ", error message: " + list_objects_outcome.GetError().GetMessage()); + } + return Status::Success; +} + +Status +S3FileSystem::GetDirectorySubdirs( + const std::string& path, std::set* subdirs) +{ + // Parse bucket and dir_path + std::string bucket, dir_path; + RETURN_IF_ERROR(ParsePath(path, &bucket, &dir_path)); + std::string true_path = "s3://" + bucket + '/' + dir_path; + + RETURN_IF_ERROR(GetDirectoryContents(true_path, subdirs)); + + // Erase non-directory entries... + for (auto iter = subdirs->begin(); iter != subdirs->end();) { + bool is_dir; + RETURN_IF_ERROR(IsDirectory(JoinPath({true_path, *iter}), &is_dir)); + if (!is_dir) { + iter = subdirs->erase(iter); + } else { + ++iter; + } + } + + return Status::Success; +} +Status +S3FileSystem::GetDirectoryFiles( + const std::string& path, std::set* files) +{ + // Parse bucket and dir_path + std::string bucket, dir_path; + RETURN_IF_ERROR(ParsePath(path, &bucket, &dir_path)); + std::string true_path = "s3://" + bucket + '/' + dir_path; + RETURN_IF_ERROR(GetDirectoryContents(true_path, files)); + + // Erase directory entries... 
+ for (auto iter = files->begin(); iter != files->end();) { + bool is_dir; + RETURN_IF_ERROR(IsDirectory(JoinPath({true_path, *iter}), &is_dir)); + if (is_dir) { + iter = files->erase(iter); + } else { + ++iter; + } + } + + return Status::Success; +} + +Status +S3FileSystem::ReadTextFile(const std::string& path, std::string* contents) +{ + bool exists; + RETURN_IF_ERROR(FileExists(path, &exists)); + + if (!exists) { + return Status(Status::Code::INTERNAL, "File does not exist at " + path); + } + + std::string bucket, object; + RETURN_IF_ERROR(ParsePath(path, &bucket, &object)); + + // Send a request for the objects metadata + s3::Model::GetObjectRequest object_request; + object_request.SetBucket(bucket.c_str()); + object_request.SetKey(object.c_str()); + + auto get_object_outcome = client_->GetObject(object_request); + if (get_object_outcome.IsSuccess()) { + auto& object_result = get_object_outcome.GetResultWithOwnership().GetBody(); + + std::string data = ""; + char c; + while (object_result.get(c)) { + data += c; + } + + *contents = data; + } else { + return Status( + Status::Code::INTERNAL, + "Failed to get object at " + path + " due to exception: " + + get_object_outcome.GetError().GetExceptionName() + + ", error message: " + get_object_outcome.GetError().GetMessage()); + } + + return Status::Success; +} + +Status +S3FileSystem::LocalizePath( + const std::string& path, std::shared_ptr* localized) +{ + // Check if the directory or file exists + bool exists; + RETURN_IF_ERROR(FileExists(path, &exists)); + if (!exists) { + return Status( + Status::Code::INTERNAL, "directory or file does not exist at " + path); + } + + // Cleanup extra slashes + std::string clean_path; + RETURN_IF_ERROR(CleanPath(path, &clean_path)); + + // Remove protocol and host name and port + std::string effective_path, protocol, host_name, host_port, bucket, object; + if (RE2::FullMatch( + clean_path, s3_regex_, &protocol, &host_name, &host_port, &bucket, + &object)) { + effective_path = "s3://" + bucket + object; + } else { + effective_path = path; + } + + // Create temporary directory + std::string tmp_folder; + RETURN_IF_ERROR( + triton::core::MakeTemporaryDirectory(FileSystemType::LOCAL, &tmp_folder)); + + // Specify contents to be downloaded + std::set contents; + bool is_dir; + RETURN_IF_ERROR(IsDirectory(path, &is_dir)); + if (is_dir) { + // Set localized path + localized->reset(new LocalizedPath(effective_path, tmp_folder)); + // Specify the entire directory to be downloaded + std::set filenames; + RETURN_IF_ERROR(GetDirectoryContents(effective_path, &filenames)); + for (auto itr = filenames.begin(); itr != filenames.end(); ++itr) { + contents.insert(JoinPath({effective_path, *itr})); + } + } else { + // Set localized path + std::string filename = + effective_path.substr(effective_path.find_last_of('/') + 1); + localized->reset( + new LocalizedPath(effective_path, JoinPath({tmp_folder, filename}))); + // Specify only the file to be downloaded + contents.insert(effective_path); + } + + // Download all specified contents and nested contents + while (contents.size() != 0) { + std::set tmp_contents = contents; + contents.clear(); + for (auto iter = tmp_contents.begin(); iter != tmp_contents.end(); ++iter) { + std::string s3_fpath = *iter; + std::string s3_removed_path = s3_fpath.substr(effective_path.size()); + std::string local_fpath = + s3_removed_path.empty() + ? 
(*localized)->Path() + : JoinPath({(*localized)->Path(), s3_removed_path}); + bool is_subdir; + RETURN_IF_ERROR(IsDirectory(s3_fpath, &is_subdir)); + if (is_subdir) { + // Create local mirror of sub-directories +#ifdef _WIN32 + int status = mkdir(const_cast(local_fpath.c_str())); +#else + int status = mkdir( + const_cast(local_fpath.c_str()), + S_IRUSR | S_IWUSR | S_IXUSR); +#endif + if (status == -1) { + return Status( + Status::Code::INTERNAL, + "Failed to create local folder: " + local_fpath + + ", errno:" + strerror(errno)); + } + + // Add sub-directories and deeper files to contents + std::set subdir_contents; + RETURN_IF_ERROR(GetDirectoryContents(s3_fpath, &subdir_contents)); + for (auto itr = subdir_contents.begin(); itr != subdir_contents.end(); + ++itr) { + contents.insert(JoinPath({s3_fpath, *itr})); + } + } else { + // Create local copy of file + std::string file_bucket, file_object; + RETURN_IF_ERROR(ParsePath(s3_fpath, &file_bucket, &file_object)); + + s3::Model::GetObjectRequest object_request; + object_request.SetBucket(file_bucket.c_str()); + object_request.SetKey(file_object.c_str()); + + auto get_object_outcome = client_->GetObject(object_request); + if (get_object_outcome.IsSuccess()) { + auto& retrieved_file = + get_object_outcome.GetResultWithOwnership().GetBody(); + std::ofstream output_file(local_fpath.c_str(), std::ios::binary); + output_file << retrieved_file.rdbuf(); + output_file.close(); + } else { + return Status( + Status::Code::INTERNAL, + "Failed to get object at " + s3_fpath + " due to exception: " + + get_object_outcome.GetError().GetExceptionName() + + ", error message: " + + get_object_outcome.GetError().GetMessage()); + } + } + } + } + + return Status::Success; +} + +Status +S3FileSystem::WriteTextFile( + const std::string& path, const std::string& contents) +{ + return Status( + Status::Code::UNSUPPORTED, + "Write text file operation not yet implemented " + path); +} + +Status +S3FileSystem::WriteBinaryFile( + const std::string& path, const char* contents, const size_t content_len) +{ + return Status( + Status::Code::UNSUPPORTED, + "Write text file operation not yet implemented " + path); +} + +Status +S3FileSystem::MakeDirectory(const std::string& dir, const bool recursive) +{ + return Status( + Status::Code::UNSUPPORTED, + "Make directory operation not yet implemented"); +} + +Status +S3FileSystem::MakeTemporaryDirectory(std::string* temp_dir) +{ + return Status( + Status::Code::UNSUPPORTED, + "Make temporary directory operation not yet implemented"); +} + +Status +S3FileSystem::DeletePath(const std::string& path) +{ + return Status( + Status::Code::UNSUPPORTED, "Delete path operation not yet implemented"); +} + + +#endif // TRITON_ENABLE_S3 + + +class FileSystemManager { + public: + Status GetFileSystem( + const std::string& path, std::shared_ptr& file_system); + Status GetFileSystem( + FileSystemType type, std::shared_ptr& file_system); + FileSystemManager(); + + private: + template + Status GetFileSystem( + const std::string& path, CacheType& cache, + std::shared_ptr& file_system); + template + Status ReturnErrorOrReload( + const Status& load_status, const Status& error_status, + const std::string& path, CacheType& cache, + std::shared_ptr& file_system); + Status LoadCredentials(bool flush_cache = false); + template + static void LoadCredential( + triton::common::TritonJson::Value& creds_json, const char* fs_type, + CacheType& cache); + template + static void SortCache( + std::vector>>& + cache); + template + static Status 
GetLongestMatchingNameIndex( + const std::vector>>& cache, + const std::string& path, size_t& idx); + + std::shared_ptr local_fs_; + std::mutex mu_; // protect concurrent access into variables + bool is_cached_; // if name and credential is cached, lazy load file system + // cloud credential cache should be sorted in descending name length order + // [(name_long, credential, file_system), (name, ...)] +#ifdef TRITON_ENABLE_GCS + std::vector< + std::tuple>> + gs_cache_; +#endif // TRITON_ENABLE_GCS +#ifdef TRITON_ENABLE_S3 + std::vector< + std::tuple>> + s3_cache_; +#endif // TRITON_ENABLE_S3 +#ifdef TRITON_ENABLE_AZURE_STORAGE + std::vector< + std::tuple>> + as_cache_; +#endif // TRITON_ENABLE_AZURE_STORAGE +}; + +FileSystemManager::FileSystemManager() + : local_fs_(new LocalFileSystem()), is_cached_(false) +{ +} + +Status +FileSystemManager::GetFileSystem( + const std::string& path, std::shared_ptr& file_system) +{ + // Check if this is a GCS path (gs://$BUCKET_NAME) + if (!path.empty() && !path.rfind("gs://", 0)) { +#ifndef TRITON_ENABLE_GCS + return Status( + Status::Code::INTERNAL, + "gs:// file-system not supported. To enable, build with " + "-DTRITON_ENABLE_GCS=ON."); +#else + return GetFileSystem< + std::vector>>, + GCSCredential, GCSFileSystem>(path, gs_cache_, file_system); +#endif // TRITON_ENABLE_GCS + } + + // Check if this is an S3 path (s3://$BUCKET_NAME) + if (!path.empty() && !path.rfind("s3://", 0)) { +#ifndef TRITON_ENABLE_S3 + return Status( + Status::Code::INTERNAL, + "s3:// file-system not supported. To enable, build with " + "-DTRITON_ENABLE_S3=ON."); +#else + return GetFileSystem< + std::vector>>, + S3Credential, S3FileSystem>(path, s3_cache_, file_system); +#endif // TRITON_ENABLE_S3 + } + + // Check if this is an Azure Storage path + if (!path.empty() && !path.rfind("as://", 0)) { +#ifndef TRITON_ENABLE_AZURE_STORAGE + return Status( + Status::Code::INTERNAL, + "as:// file-system not supported. 
To enable, build with " + "-DTRITON_ENABLE_AZURE_STORAGE=ON."); +#else + return GetFileSystem< + std::vector>>, + ASCredential, ASFileSystem>(path, as_cache_, file_system); +#endif // TRITON_ENABLE_AZURE_STORAGE + } + + // Assume path is for local filesystem + file_system = local_fs_; + return Status::Success; +} + +Status +FileSystemManager::GetFileSystem( + FileSystemType type, std::shared_ptr& file_system) +{ + // only LOCAL and GCS are not path-dependent and can be accessed by type + switch (type) { + case FileSystemType::LOCAL: + return GetFileSystem("", file_system); + case FileSystemType::GCS: + return GetFileSystem("gs://", file_system); + case FileSystemType::S3: + return Status( + Status::Code::UNSUPPORTED, + "S3 filesystem cannot be accessed by type"); + case FileSystemType::AS: + return Status( + Status::Code::UNSUPPORTED, + "AS filesystem cannot be accessed by type"); + default: + return Status(Status::Code::UNSUPPORTED, "Unsupported filesystem type"); + } +} + +template +Status +FileSystemManager::GetFileSystem( + const std::string& path, CacheType& cache, + std::shared_ptr& file_system) +{ + const Status& cred_status = LoadCredentials(); + if (cred_status.IsOk() || + cred_status.StatusCode() == Status::Code::ALREADY_EXISTS) { + // Find credential + size_t idx; + const Status& match_status = GetLongestMatchingNameIndex(cache, path, idx); + if (!match_status.IsOk()) { + return ReturnErrorOrReload( + cred_status, match_status, path, cache, file_system); + } + // Find or lazy load file system + std::shared_ptr fs = std::get<2>(cache[idx]); + if (fs == nullptr) { + std::string cred_name = std::get<0>(cache[idx]); + CredentialType cred = std::get<1>(cache[idx]); + fs = std::make_shared(path, cred); + cache[idx] = std::make_tuple(cred_name, cred, fs); + } + // Check client + const Status& client_status = fs->CheckClient(path); + if (!client_status.IsOk()) { + return ReturnErrorOrReload( + cred_status, client_status, path, cache, file_system); + } + // Return client + file_system = fs; + return Status::Success; + } + return cred_status; +} + +template +Status +FileSystemManager::ReturnErrorOrReload( + const Status& load_status, const Status& error_status, + const std::string& path, CacheType& cache, + std::shared_ptr& file_system) +{ + if (load_status.StatusCode() == Status::Code::ALREADY_EXISTS) { + return error_status; + } + LoadCredentials(true); // flush cache + return GetFileSystem( + path, cache, file_system); +} + +// return status meaning: +// - SUCCESS, "" -> loaded credential from file +// - ALREADY_EXISTS, "Cached" -> credential already loaded +Status +FileSystemManager::LoadCredentials(bool flush_cache) +{ + // prevent concurrent access into class variables + std::lock_guard lock(mu_); + + // check if credential is already cached + if (is_cached_ && !flush_cache) { + return Status(Status::Code::ALREADY_EXISTS, "Cached"); + } + + const char* file_path_c_str = std::getenv("TRITON_CLOUD_CREDENTIAL_PATH"); + if (file_path_c_str != nullptr) { + // Load from credential file + std::string file_path = std::string(file_path_c_str); + LOG_VERBOSE(1) << "Reading cloud credential from " << file_path; + + triton::common::TritonJson::Value creds_json; + std::string cred_file_content; + RETURN_IF_ERROR(local_fs_->ReadTextFile(file_path, &cred_file_content)); + RETURN_IF_ERROR(creds_json.Parse(cred_file_content)); + +#ifdef TRITON_ENABLE_GCS + // load GCS credentials + LoadCredential< + std::vector>>, + GCSCredential, GCSFileSystem>(creds_json, "gs", gs_cache_); +#endif // 
TRITON_ENABLE_GCS +#ifdef TRITON_ENABLE_S3 + // load S3 credentials + LoadCredential< + std::vector>>, + S3Credential, S3FileSystem>(creds_json, "s3", s3_cache_); +#endif // TRITON_ENABLE_S3 +#ifdef TRITON_ENABLE_AZURE_STORAGE + // load AS credentials + LoadCredential< + std::vector>>, + ASCredential, ASFileSystem>(creds_json, "as", as_cache_); +#endif // TRITON_ENABLE_AZURE_STORAGE + } else { + // Load from environment variables + LOG_VERBOSE(1) << "TRITON_CLOUD_CREDENTIAL_PATH environment variable is " + "not set, reading from environment variables"; + +#ifdef TRITON_ENABLE_GCS + // load GCS credentials + gs_cache_.clear(); + gs_cache_.push_back( + std::make_tuple("", GCSCredential(), std::shared_ptr())); +#endif // TRITON_ENABLE_GCS + +#ifdef TRITON_ENABLE_S3 + // load S3 credentials + s3_cache_.clear(); + s3_cache_.push_back( + std::make_tuple("", S3Credential(), std::shared_ptr())); +#endif // TRITON_ENABLE_S3 + +#ifdef TRITON_ENABLE_AZURE_STORAGE + // load AS credentials + as_cache_.clear(); + as_cache_.push_back( + std::make_tuple("", ASCredential(), std::shared_ptr())); +#endif // TRITON_ENABLE_AZURE_STORAGE + } + + is_cached_ = true; + return Status::Success; +} + +template +void +FileSystemManager::LoadCredential( + triton::common::TritonJson::Value& creds_json, const char* fs_type, + CacheType& cache) +{ + cache.clear(); + triton::common::TritonJson::Value creds_fs_json; + if (creds_json.Find(fs_type, &creds_fs_json)) { + std::vector cred_names; + creds_fs_json.Members(&cred_names); + for (size_t i = 0; i < cred_names.size(); i++) { + std::string cred_name = cred_names[i]; + triton::common::TritonJson::Value cred_json; + creds_fs_json.Find(cred_name.c_str(), &cred_json); + cache.push_back(std::make_tuple( + cred_name, CredentialType(cred_json), + std::shared_ptr())); + } + SortCache(cache); + } +} + +template +void +FileSystemManager::SortCache( + std::vector>>& cache) +{ + std::sort( + cache.begin(), cache.end(), + [](std::tuple< + std::string, CredentialType, std::shared_ptr> + a, + std::tuple< + std::string, CredentialType, std::shared_ptr> + b) { return std::get<0>(a).size() >= std::get<0>(b).size(); }); +} + +template +Status +FileSystemManager::GetLongestMatchingNameIndex( + const std::vector>>& cache, + const std::string& path, size_t& idx) +{ + for (size_t i = 0; i < cache.size(); i++) { + if (!path.rfind(std::get<0>(cache[i]), 0)) { + idx = i; + LOG_VERBOSE(1) << "Using credential " + std::get<0>(cache[i]) + + " for path " + path; + return Status::Success; + } + } + return Status( + Status::Code::NOT_FOUND, "Cannot match credential for path " + path); +} + +static FileSystemManager fsm_; + +} // namespace + +// FIXME: Windows support '/'? 
If so, the below doesn't need to change +bool +IsAbsolutePath(const std::string& path) +{ + return !path.empty() && (path[0] == '/'); +} + +std::string +JoinPath(std::initializer_list segments) +{ + std::string joined; + + for (const auto& seg : segments) { + if (joined.empty()) { + joined = seg; + } else if (IsAbsolutePath(seg)) { + if (joined[joined.size() - 1] == '/') { + joined.append(seg.substr(1)); + } else { + joined.append(seg); + } + } else { // !IsAbsolutePath(seg) + if (joined[joined.size() - 1] != '/') { + joined.append("/"); + } + joined.append(seg); + } + } + + return joined; +} + +std::string +BaseName(const std::string& path) +{ + if (path.empty()) { + return path; + } + + size_t last = path.size() - 1; + while ((last > 0) && (path[last] == '/')) { + last -= 1; + } + + if (path[last] == '/') { + return std::string(); + } + + const size_t idx = path.find_last_of("/", last); + if (idx == std::string::npos) { + return path.substr(0, last + 1); + } + + return path.substr(idx + 1, last - idx); +} + +std::string +DirName(const std::string& path) +{ + if (path.empty()) { + return path; + } + + size_t last = path.size() - 1; + while ((last > 0) && (path[last] == '/')) { + last -= 1; + } + + if (path[last] == '/') { + return std::string("/"); + } + + const size_t idx = path.find_last_of("/", last); + if (idx == std::string::npos) { + return std::string("."); + } + if (idx == 0) { + return std::string("/"); + } + + return path.substr(0, idx); +} + +Status +FileExists(const std::string& path, bool* exists) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + return fs->FileExists(path, exists); +} + +Status +IsDirectory(const std::string& path, bool* is_dir) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + return fs->IsDirectory(path, is_dir); +} + +Status +FileModificationTime(const std::string& path, int64_t* mtime_ns) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + return fs->FileModificationTime(path, mtime_ns); +} + +Status +GetDirectoryContents(const std::string& path, std::set* contents) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + return fs->GetDirectoryContents(path, contents); +} + +Status +GetDirectorySubdirs(const std::string& path, std::set* subdirs) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + return fs->GetDirectorySubdirs(path, subdirs); +} + +Status +GetDirectoryFiles( + const std::string& path, const bool skip_hidden_files, + std::set* files) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + std::set all_files; + RETURN_IF_ERROR(fs->GetDirectoryFiles(path, &all_files)); + // Remove the hidden files + for (auto f : all_files) { + if ((f[0] != '.') || (!skip_hidden_files)) { + files->insert(f); + } + } + return Status::Success; +} + +Status +ReadTextFile(const std::string& path, std::string* contents) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + return fs->ReadTextFile(path, contents); +} + +Status +ReadTextProto(const std::string& path, google::protobuf::Message* msg) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + + std::string contents; + RETURN_IF_ERROR(fs->ReadTextFile(path, &contents)); + + if (!google::protobuf::TextFormat::ParseFromString(contents, msg)) { + return Status( + Status::Code::INTERNAL, "failed to read text proto from " + path); + } + + return Status::Success; +} + +Status +LocalizePath(const std::string& path, std::shared_ptr* 
localized) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + return fs->LocalizePath(path, localized); +} + +Status +WriteTextProto(const std::string& path, const google::protobuf::Message& msg) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + + std::string prototxt; + if (!google::protobuf::TextFormat::PrintToString(msg, &prototxt)) { + return Status( + Status::Code::INTERNAL, "failed to write text proto to " + path); + } + + return fs->WriteTextFile(path, prototxt); +} + +Status +WriteBinaryFile( + const std::string& path, const char* contents, const size_t content_len) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + return fs->WriteBinaryFile(path, contents, content_len); +} + +Status +ReadBinaryProto(const std::string& path, google::protobuf::MessageLite* msg) +{ + std::string msg_str; + RETURN_IF_ERROR(ReadTextFile(path, &msg_str)); + + google::protobuf::io::CodedInputStream coded_stream( + reinterpret_cast(msg_str.c_str()), msg_str.size()); + coded_stream.SetTotalBytesLimit(INT_MAX); + if (!msg->ParseFromCodedStream(&coded_stream)) { + return Status( + Status::Code::INTERNAL, "Can't parse " + path + " as binary proto"); + } + + return Status::Success; +} + +Status +MakeDirectory(const std::string& dir, const bool recursive) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(dir, fs)); + return fs->MakeDirectory(dir, recursive); +} + +Status +MakeTemporaryDirectory(const FileSystemType type, std::string* temp_dir) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(type, fs)); + return fs->MakeTemporaryDirectory(temp_dir); +} + +Status +DeletePath(const std::string& path) +{ + std::shared_ptr fs; + RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs)); + return fs->DeletePath(path); +} + +Status +GetFileSystemType(const std::string& path, FileSystemType* type) +{ + if (path.empty()) { + return Status( + Status::Code::INVALID_ARG, + "Can not infer filesystem type from empty path"); + } +#ifdef TRITON_ENABLE_GCS + // Check if this is a GCS path (gs://$BUCKET_NAME) + if (!path.rfind("gs://", 0)) { + *type = FileSystemType::GCS; + return Status::Success; + } +#endif // TRITON_ENABLE_GCS + +#ifdef TRITON_ENABLE_S3 + // Check if this is an S3 path (s3://$BUCKET_NAME) + if (!path.rfind("s3://", 0)) { + *type = FileSystemType::S3; + return Status::Success; + } +#endif // TRITON_ENABLE_S3 + +#ifdef TRITON_ENABLE_AZURE_STORAGE + // Check if this is an Azure Storage path + if (!path.rfind("as://", 0)) { + *type = FileSystemType::AS; + return Status::Success; + } +#endif // TRITON_ENABLE_AZURE_STORAGE + + // Assume path is for local filesystem + *type = FileSystemType::LOCAL; + return Status::Success; +} + +const std::string& +FileSystemTypeString(const FileSystemType type) +{ + static const std::string local_str("LOCAL"); + static const std::string gcs_str("GCS"); + static const std::string s3_str("S3"); + static const std::string as_str("AS"); + static const std::string unknown_str("UNKNOWN"); + switch (type) { + case FileSystemType::LOCAL: + return local_str; + case FileSystemType::GCS: + return gcs_str; + case FileSystemType::S3: + return s3_str; + case FileSystemType::AS: + return as_str; + default: + return unknown_str; + } +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/filesystem.h b/3rdparty/core-r22.12/src/filesystem.h new file mode 100644 index 0000000000000000000000000000000000000000..439e42570c1d635725c2ec7f1ee6e0a62ae41673 --- /dev/null +++ 
b/3rdparty/core-r22.12/src/filesystem.h @@ -0,0 +1,224 @@ +// Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#ifdef _WIN32 +// Remove GetObject definition from windows.h, which can cause +// a naming collision when GetObject is called. +// https://github.com/Tencent/rapidjson/issues/1448 +#undef GetObject +#endif // _WIN32 + +#include +#include "google/protobuf/message.h" +#include "status.h" + +namespace triton { namespace core { + +enum class FileSystemType { LOCAL, GCS, S3, AS }; + +// This class stores the paths of local temporary files needed for loading +// models from Cloud repositories and performs necessary cleanup after the +// models are loaded. +class LocalizedPath { + public: + // Create an object for a path that is already local. + LocalizedPath(const std::string& original_path) + : original_path_(original_path) + { + } + + // Create an object for a remote path. Store both the original path and the + // temporary local path. + LocalizedPath( + const std::string& original_path, const std::string& local_path) + : original_path_(original_path), local_path_(local_path) + { + } + + // Destructor. Remove temporary local storage associated with the object. + // If the local path is a directory, delete the directory. + // If the local path is a file, delete the directory containing the file. + ~LocalizedPath(); + + // Return the localized path represented by this object. + const std::string& Path() const + { + return (local_path_.empty()) ? original_path_ : local_path_; + } + + // Maintain a vector of LocalizedPath that should be kept available in the + // tmp directory for the lifetime of this object + // FIXME: Remove when no longer required + std::vector> other_localized_path; + + private: + std::string original_path_; + std::string local_path_; +}; + +/// Is a path an absolute path? +/// \param path The path. +/// \return true if absolute path, false if relative path. 
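+/// Example (illustrative): IsAbsolutePath("/opt/models") is true, while
+/// IsAbsolutePath("models/resnet") is false.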
+bool IsAbsolutePath(const std::string& path); + +/// Join path segments into a longer path +/// \param segments The path segments. +/// \return the path formed by joining the segments. +std::string JoinPath(std::initializer_list segments); + +/// Get the basename of a path. +/// \param path The path. +/// \return the last segment of the path. +std::string BaseName(const std::string& path); + +/// Get the dirname of a path. +/// \param path The path. +/// \return all but the last segment of the path. +std::string DirName(const std::string& path); + +/// Does a file or directory exist? +/// \param path The path to check for existance. +/// \param exists Returns true if file/dir exists +/// \return Error status if unable to perform the check +Status FileExists(const std::string& path, bool* exists); + +/// Is a path a directory? +/// \param path The path to check. +/// \param is_dir Returns true if path represents a directory +/// \return Error status +Status IsDirectory(const std::string& path, bool* is_dir); + +/// Get file modification time in nanoseconds. +/// A file is considered modified in Triton when its binary content has changed +/// including the action of replacing it with another file. +/// \param path The path. +/// \param mtime_ns Returns the file modification time. For some filesystems a +/// file/folder may not have a modification time, in that case return 0. +/// \return Error status +Status FileModificationTime(const std::string& path, int64_t* mtime_ns); + +/// Get the contents of a directory. +/// \param path The directory path. +/// \param subdirs Returns the directory contents. +/// \return Error status +Status GetDirectoryContents( + const std::string& path, std::set* contents); + +/// Get the sub-directories of a path. +/// \param path The path. +/// \param subdirs Returns the names of the sub-directories. +/// \return Error status +Status GetDirectorySubdirs( + const std::string& path, std::set* subdirs); + +/// Get the files contained in a directory. +/// \param path The directory. +/// \param skip_hidden_files Ignores the hidden files in the directory. +/// \param files Returns the names of the files. +/// \return Error status +Status GetDirectoryFiles( + const std::string& path, const bool skip_hidden_files, + std::set* files); + +/// Read a text file into a string. +/// \param path The path of the file. +/// \param contents Returns the contents of the file. +/// \return Error status +Status ReadTextFile(const std::string& path, std::string* contents); + +/// Create an object representing a local copy of a path. +/// \param path The path of the directory or file. +/// \param localized Returns the LocalizedPath object +/// representing the local copy of the path. +/// \return Error status +Status LocalizePath( + const std::string& path, std::shared_ptr* localized); + +/// Write a string to a file. +/// \param path The path of the file. +/// \param contents The contents to write to the file. +/// \return Error status +Status WriteTextFile(const std::string& path, const std::string& contents); + +/// Write binary to a file. +/// \param path The path of the file. +/// \param contents The contents to write to the file. +/// \param content_len The size of the content. +/// \return Error status +Status WriteBinaryFile( + const std::string& path, const char* contents, const size_t content_len); + +/// Read a prototext file. +/// \param path The path of the file. +/// \param msg Returns the protobuf message for the file. 
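+/// The file contents are parsed with google::protobuf::TextFormat.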
+/// \return Error status +Status ReadTextProto(const std::string& path, google::protobuf::Message* msg); + +/// Write a prototext file. +/// \param path The path of the file. +/// \param msg The protobuf to write. +/// \return Error status +Status WriteTextProto( + const std::string& path, const google::protobuf::Message& msg); + +/// Read a binary protobuf file. +/// \param path The path of the file. +/// \param msg Returns the protobuf message for the file. +/// \return Error status +Status ReadBinaryProto( + const std::string& path, google::protobuf::MessageLite* msg); + +/// Create a directory of the specified path. +/// \param dir The path to the directory. +/// \param recursive Whether the parent directories will be created +/// if not exist. +/// \return Error status if the directory can't be created +Status MakeDirectory(const std::string& dir, const bool recursive); + +/// Create a temporary directory of the specified filesystem type. +/// \param type The type of the filesystem. +/// \param temp_dir Returns the path to the temporary directory. +/// \return Error status +Status MakeTemporaryDirectory(const FileSystemType type, std::string* temp_dir); + +/// Delete a path. +/// \param path The path to the directory or file. +/// \return Error status +Status DeletePath(const std::string& path); + +/// Infer the filesystem type from the given path. +/// \param path The path to infer the filesystem type from. +/// \param type Returns the filesystem type of the path. +/// \return Error status +Status GetFileSystemType(const std::string& path, FileSystemType* type); + +/// Return the string representation of the filesystem type. +/// \param type The filesystem type. +/// \return The string representation of the type. +const std::string& FileSystemTypeString(const FileSystemType type); + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/infer_parameter.cc b/3rdparty/core-r22.12/src/infer_parameter.cc new file mode 100644 index 0000000000000000000000000000000000000000..49f8c494a30d953724223d2c1ed8c5fabafa2945 --- /dev/null +++ b/3rdparty/core-r22.12/src/infer_parameter.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "infer_parameter.h" + +namespace triton { namespace core { + + +const void* +InferenceParameter::ValuePointer() const +{ + switch (type_) { + case TRITONSERVER_PARAMETER_STRING: + return reinterpret_cast(value_string_.c_str()); + case TRITONSERVER_PARAMETER_INT: + return reinterpret_cast(&value_int64_); + case TRITONSERVER_PARAMETER_BOOL: + return reinterpret_cast(&value_bool_); + case TRITONSERVER_PARAMETER_BYTES: + return reinterpret_cast(value_bytes_); + default: + break; + } + + return nullptr; +} + +std::ostream& +operator<<(std::ostream& out, const InferenceParameter& parameter) +{ + out << "[0x" << std::addressof(parameter) << "] " + << "name: " << parameter.Name() + << ", type: " << TRITONSERVER_ParameterTypeString(parameter.Type()) + << ", value: "; + return out; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/infer_parameter.h b/3rdparty/core-r22.12/src/infer_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..0e5b758016ace701d98b782b3cfcb8dfb4c2ec1f --- /dev/null +++ b/3rdparty/core-r22.12/src/infer_parameter.h @@ -0,0 +1,102 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include "tritonserver_apis.h" + +namespace triton { namespace core { + +// +// An inference parameter. 
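+// Wraps a single named, typed value (string, int64, bool, or raw bytes)
+// attached to an inference request; ValueByteSize() reports the size of the
+// stored value.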
+// +class InferenceParameter { + public: + InferenceParameter(const char* name, const char* value) + : name_(name), type_(TRITONSERVER_PARAMETER_STRING), value_string_(value) + { + byte_size_ = value_string_.size(); + } + + InferenceParameter(const char* name, const int64_t value) + : name_(name), type_(TRITONSERVER_PARAMETER_INT), value_int64_(value), + byte_size_(sizeof(int64_t)) + { + } + + InferenceParameter(const char* name, const bool value) + : name_(name), type_(TRITONSERVER_PARAMETER_BOOL), value_bool_(value), + byte_size_(sizeof(bool)) + { + } + + InferenceParameter(const char* name, const void* ptr, const uint64_t size) + : name_(name), type_(TRITONSERVER_PARAMETER_BYTES), value_bytes_(ptr), + byte_size_(size) + { + } + + // The name of the parameter. + const std::string& Name() const { return name_; } + + // Data type of the parameter. + TRITONSERVER_ParameterType Type() const { return type_; } + + // Return a pointer to the parameter, or a pointer to the data content + // if type_ is TRITONSERVER_PARAMETER_BYTES. This returned pointer must be + // cast correctly based on 'type_'. + // TRITONSERVER_PARAMETER_STRING -> const char* + // TRITONSERVER_PARAMETER_INT -> int64_t* + // TRITONSERVER_PARAMETER_BOOL -> bool* + // TRITONSERVER_PARAMETER_BYTES -> const void* + const void* ValuePointer() const; + + // Return the data byte size of the parameter. + uint64_t ValueByteSize() const { return byte_size_; } + + // Return the parameter value string, the return value is valid only if + // Type() returns TRITONSERVER_PARAMETER_STRING + const std::string& ValueString() const { return value_string_; } + + private: + friend std::ostream& operator<<( + std::ostream& out, const InferenceParameter& parameter); + + std::string name_; + TRITONSERVER_ParameterType type_; + + std::string value_string_; + int64_t value_int64_; + bool value_bool_; + const void* value_bytes_; + uint64_t byte_size_; +}; + +std::ostream& operator<<( + std::ostream& out, const InferenceParameter& parameter); + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/infer_request.cc b/3rdparty/core-r22.12/src/infer_request.cc new file mode 100644 index 0000000000000000000000000000000000000000..149fe84527d290961c37cb2c50a49fc2778cd972 --- /dev/null +++ b/3rdparty/core-r22.12/src/infer_request.cc @@ -0,0 +1,1498 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "infer_request.h" + +#include +#include +#include "model.h" +#include "model_config_utils.h" +#include "server.h" +#include "triton/common/logging.h" +#ifdef TRITON_ENABLE_TRACING +#include "cuda_utils.h" +#endif // TRITON_ENABLE_TRACING + +namespace triton { namespace core { + +namespace { + +// Utilities for Null request feature. +TRITONSERVER_Error* +NullResponseAlloc( + TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type, + int64_t preferred_memory_type_id, void* userp, void** buffer, + void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type, + int64_t* actual_memory_type_id) +{ + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected allocation for null request, no output should be requested."); +} + +TRITONSERVER_Error* +NullResponseRelease( + TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp, + size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) +{ + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected release for null request, no output should be requested."); +} + +ResponseAllocator null_allocator = ResponseAllocator( + NullResponseAlloc, NullResponseRelease, nullptr /* start_fn */); + +void +NullResponseComplete( + TRITONSERVER_InferenceResponse* iresponse, const uint32_t flags, + void* userp) +{ + if (iresponse != nullptr) { + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceResponseDelete(iresponse), + "deleting null response"); + } +} + +void +NullRequestComplete( + TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp) +{ + if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) { + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestDelete(request), "deleting null request"); + } +} + +} // namespace + +InferenceRequest::InferenceRequest( + const std::shared_ptr& model, const int64_t requested_model_version) + : InferenceRequest(model.get(), requested_model_version) +{ + model_shared_ = model; +} + +InferenceRequest::InferenceRequest( + Model* model, const int64_t requested_model_version) + : needs_normalization_(true), model_raw_(model), + requested_model_version_(requested_model_version), flags_(0), + correlation_id_(0), batch_size_(0), timeout_us_(0), collect_stats_(true) +{ + SetPriority(0); +} + +const std::string& +InferenceRequest::ModelName() const +{ + return model_raw_->Name(); +} + +int64_t +InferenceRequest::ActualModelVersion() const +{ + return model_raw_->Version(); +} + +void +InferenceRequest::SetPriority(uint32_t p) +{ + if ((p == 0) || (p > model_raw_->MaxPriorityLevel())) { + priority_ = model_raw_->DefaultPriorityLevel(); + } else { + priority_ = p; + } +} + +#ifdef TRITON_ENABLE_TRACING +Status +InferenceRequest::TraceInputTensors( + TRITONSERVER_InferenceTraceActivity activity, const std::string& msg) +{ + const auto& inputs = this->ImmutableInputs(); + TRITONSERVER_MemoryType dst_memory_type = 
TRITONSERVER_MEMORY_CPU; + int64_t dst_memory_type_id = 0; + + for (const auto& pr : inputs) { + InferenceRequest::Input* ti = pr.second; + + // input data + const std::string& name = ti->Name(); + TRITONSERVER_DataType datatype = DataTypeToTriton(ti->DType()); + uint64_t byte_size = ti->Data()->TotalByteSize(); + const int64_t* shape = ti->ShapeWithBatchDim().data(); + uint32_t dim_count = ti->ShapeWithBatchDim().size(); + uint32_t buffer_count = ti->DataBufferCount(); + // chunk buffer + Status status; + const void* buffer; + uint64_t buffer_size; + TRITONSERVER_MemoryType src_memory_type; + int64_t src_memory_type_id; + bool cuda_used; + + if (buffer_count == 0) { + LOG_STATUS_ERROR( + status, LogRequest() + + TRITONSERVER_InferenceTraceActivityString(activity) + + ": " + msg + ": tensor: " + name + ": no buffer chunk"); + continue; + } + + if (buffer_count == 1) { + status = ti->DataBuffer( + 0, &buffer, &buffer_size, &src_memory_type, &src_memory_type_id); + if (!status.IsOk()) { + LOG_STATUS_ERROR( + status, LogRequest() + + TRITONSERVER_InferenceTraceActivityString(activity) + + ": " + msg + ": tensor: " + name + + ": fail to get data buffer: " + status.Message()); + return status; + } + + if (buffer_size != byte_size) { + LOG_STATUS_ERROR( + status, + LogRequest() + TRITONSERVER_InferenceTraceActivityString(activity) + + ": " + msg + ": tensor: " + name + ": truncated buffer"); + continue; + } + + INFER_TRACE_TENSOR_ACTIVITY( + this->trace_, activity, name.c_str(), datatype, + const_cast(buffer), buffer_size, shape, dim_count, + src_memory_type, src_memory_type_id); + + continue; + } + + // input buffer + std::vector in_buffer(byte_size); + char* base = in_buffer.data(); + size_t offset = 0; + for (uint32_t b = 0; b < buffer_count; ++b) { + status = ti->DataBuffer( + b, &buffer, &buffer_size, &src_memory_type, &src_memory_type_id); + if (!status.IsOk()) { + LOG_STATUS_ERROR( + status, LogRequest() + + TRITONSERVER_InferenceTraceActivityString(activity) + + ": " + msg + ": tensor: " + name + + ": fail to get data buffer: " + status.Message()); + return status; + } + + status = CopyBuffer( + "InferenceRequest TraceInputTensors", src_memory_type, + src_memory_type_id, dst_memory_type, dst_memory_type_id, buffer_size, + buffer, base + offset, nullptr, &cuda_used); + if (!status.IsOk()) { + LOG_STATUS_ERROR( + status, LogRequest() + + TRITONSERVER_InferenceTraceActivityString(activity) + + ": " + msg + ": tensor: " + name + + ": fail to copy buffer: " + status.Message()); + return status; + } + + offset += buffer_size; + } + + INFER_TRACE_TENSOR_ACTIVITY( + this->trace_, activity, name.c_str(), datatype, + static_cast(base), byte_size, shape, dim_count, dst_memory_type, + dst_memory_type_id); + } + + return Status::Success; +} +#endif // TRITON_ENABLE_TRACING + +Status +InferenceRequest::OutputBufferProperties( + const char* name, size_t* byte_size, TRITONSERVER_MemoryType* memory_type, + int64_t* memory_type_id) +{ + const auto allocator = response_factory_->Allocator(); + if ((allocator == nullptr) || (allocator->QueryFn() == nullptr)) { + return Status( + Status::Code::UNAVAILABLE, + (LogRequest() + "Output properties are not available").c_str()); + } else { + RETURN_IF_TRITONSERVER_ERROR(allocator->QueryFn()( + reinterpret_cast( + const_cast(allocator)), + response_factory_->AllocatorUserp(), name, byte_size, memory_type, + memory_type_id)); + } + return Status::Success; +} + +Status +InferenceRequest::Run(std::unique_ptr& request) +{ + return request->model_raw_->Enqueue(request); 
+} + +void +InferenceRequest::RespondIfError( + std::unique_ptr& request, const Status& status, + const bool release_request) +{ + if (status.IsOk()) { + return; + } + + // Use the response factory to create a response, set the status, + // and send it. If something goes wrong all we can do is log the + // error. Because this is sending an error we assume that this is + // the last response for the request and so set the FINAL flag. + std::unique_ptr response; + LOG_STATUS_ERROR( + request->response_factory_->CreateResponse(&response), + (request->LogRequest() + "failed to create error response").c_str()); + LOG_STATUS_ERROR( + InferenceResponse::SendWithStatus( + std::move(response), TRITONSERVER_RESPONSE_COMPLETE_FINAL, status), + (request->LogRequest() + "failed to send error response").c_str()); + + // If releasing the request then invoke the release callback which + // gives ownership to the callback. So can't access 'request' after + // this point. + if (release_request) { + InferenceRequest::Release( + std::move(request), TRITONSERVER_REQUEST_RELEASE_ALL); + } +} + +void +InferenceRequest::RespondIfError( + std::vector>& requests, + const Status& status, const bool release_requests) +{ + if (status.IsOk()) { + return; + } + + for (auto& request : requests) { + RespondIfError(request, status, release_requests); + } +} + +void +InferenceRequest::Release( + std::unique_ptr&& request, const uint32_t release_flags) +{ + // Invoke the release callbacks added internally before releasing the + // request to user provided callback. + for (auto it = request->release_callbacks_.rbegin(); + it != request->release_callbacks_.rend(); it++) { + (*it)(); + } + request->release_callbacks_.clear(); + +#ifdef TRITON_ENABLE_TRACING + // If tracing then record request end and release the trace. + // This must be before the request callback to ensure the trace + // is properly layered, as the request may be nested in an ensemble + // and the callback may interact with upper level trace. + if (request->trace_ != nullptr) { + request->trace_->ReportNow(TRITONSERVER_TRACE_REQUEST_END); + request->ReleaseTrace(); + } +#endif // TRITON_ENABLE_TRACING + + void* userp = request->release_userp_; + auto& release_fn = request->release_fn_; + release_fn( + reinterpret_cast(request.release()), + release_flags, userp); +} + +InferenceRequest* +InferenceRequest::CopyAsNull(const InferenceRequest& from) +{ + // Create a copy of 'from' request with artifical inputs and no requested + // outputs. Maybe more efficient to share inputs and other metadata, + // but that binds the Null request with 'from' request's lifecycle. + std::unique_ptr lrequest( + new InferenceRequest(from.model_raw_, from.requested_model_version_)); + lrequest->needs_normalization_ = false; + lrequest->batch_size_ = from.batch_size_; + lrequest->collect_stats_ = false; + + // Three passes: first to construct input for the shape tensors inputs, second + // to obtain the max input byte size for allocating a large enough buffer for + // all non shape tensor inputs; third to construct the inputs for these + // tensors. + // First pass + for (const auto& input : from.OriginalInputs()) { + // Handle only shape tensors in this pass + if (!input.second.IsShapeTensor()) { + continue; + } + + // Prepare the memory to hold input data + size_t byte_size = input.second.Data()->TotalByteSize(); + auto mem_type = TRITONSERVER_MEMORY_CPU; + int64_t mem_id = 0; + std::shared_ptr data = + std::make_shared(byte_size, mem_type, mem_id); + + // Get the source buffer. 
Assumes shape tensors be in a single buffer on the + // CPU + const auto& from_data = input.second.Data(); + size_t from_data_byte_size; + TRITONSERVER_MemoryType from_data_memory_type; + int64_t from_data_memory_id; + const char* from_data_buffer = from_data->BufferAt( + 0 /* idx */, &from_data_byte_size, &from_data_memory_type, + &from_data_memory_id); + + if (from_data_byte_size != byte_size) { + LOG_WARNING + << lrequest->LogRequest() + << "The byte size of shape tensor to be copied does not match"; + } + + // Copy the shape values to the input buffer + std::memcpy(data->MutableBuffer(), from_data_buffer, from_data_byte_size); + + Input* new_input; + lrequest->AddOriginalInput( + input.first, input.second.DType(), input.second.Shape(), &new_input); + + // Must normalize shape here... + *new_input->MutableShape() = input.second.Shape(); + *new_input->MutableShapeWithBatchDim() = input.second.ShapeWithBatchDim(); + + new_input->SetData(data); + } + + // Second pass + size_t max_byte_size = 0; + size_t max_str_byte_size = 0; + const std::string* max_input_name; + for (const auto& input : from.OriginalInputs()) { + // Skip shape tensors in this pass + if (input.second.IsShapeTensor()) { + continue; + } + + if (input.second.DType() == inference::DataType::TYPE_STRING) { + int64_t element_count = + triton::common::GetElementCount(input.second.Shape()); + + size_t str_byte_size = static_cast(4 * element_count); + max_str_byte_size = std::max(str_byte_size, max_str_byte_size); + if (str_byte_size > max_byte_size) { + max_byte_size = str_byte_size; + max_input_name = &(input.first); + } + } else { + if (input.second.Data()->TotalByteSize() >= max_byte_size) { + max_byte_size = input.second.Data()->TotalByteSize(); + max_input_name = &(input.first); + } + } + } + + // Third pass + // [DLIS-1268] should use one growable static buffer for all null requests + auto mem_type = TRITONSERVER_MEMORY_CPU; + int64_t mem_id = 0; + std::shared_ptr data = + std::make_shared(max_byte_size, mem_type, mem_id); + auto data_base = data->BufferAt(0, &max_byte_size, &mem_type, &mem_id); + + // Zero initialization is only required when there is a TYPE_BYTES tensor in + // the request. Only set the required number of bytes to zero. + if (max_str_byte_size > 0) { + std::fill( + data->MutableBuffer(), data->MutableBuffer() + max_str_byte_size, 0); + } + + for (const auto& input : from.OriginalInputs()) { + // skip shape tensors in this pass + if (input.second.IsShapeTensor()) { + continue; + } + Input* new_input; + lrequest->AddOriginalInput( + input.first, input.second.DType(), input.second.Shape(), &new_input); + + // Must normalize shape here... + *new_input->MutableShape() = input.second.Shape(); + *new_input->MutableShapeWithBatchDim() = input.second.ShapeWithBatchDim(); + + // Note that the input that have max byte size will be responsible for + // holding the artifical data, while other inputs will hold a reference to + // it with byte size that matches 'from' + if (input.first == *max_input_name) { + new_input->SetData(data); + } else { + if (inference::DataType::TYPE_STRING == input.second.DType()) { + new_input->AppendData( + data_base, + triton::common::GetElementCount(input.second.Shape()) * 4, mem_type, + mem_id); + } else { + new_input->AppendData( + data_base, input.second.Data()->TotalByteSize(), mem_type, mem_id); + } + } + } + + // No outputs were requested and thus there should be no allocations. 
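+  // The file-scope 'null_allocator' installed below pairs NullResponseAlloc
+  // and NullResponseRelease, both of which return an INTERNAL error, so a
+  // null request that is ever asked to produce an output fails loudly
+  // instead of allocating a buffer.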
+ lrequest->SetResponseCallback( + &null_allocator, nullptr, NullResponseComplete, nullptr); + lrequest->SetReleaseCallback(NullRequestComplete, nullptr); + + // Must normalize inputs here... + for (auto& pr : lrequest->original_inputs_) { + lrequest->inputs_.emplace( + std::make_pair(pr.second.Name(), std::addressof(pr.second))); + } + + return lrequest.release(); +} + +Status +InferenceRequest::MutableOriginalInput( + const std::string& name, InferenceRequest::Input** input) +{ + auto itr = original_inputs_.find(name); + if (itr == original_inputs_.end()) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input '" + name + "' does not exist in request"); + } + + *input = &(itr->second); + + return Status::Success; +} + +Status +InferenceRequest::ImmutableInput( + const std::string& name, const InferenceRequest::Input** input) const +{ + auto itr = inputs_.find(name); + if (itr == inputs_.end()) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input '" + name + "' does not exist in request"); + } + + *input = itr->second; + return Status::Success; +} + +Status +InferenceRequest::AddOriginalInput( + const std::string& name, const inference::DataType datatype, + const int64_t* shape, const uint64_t dim_count, + InferenceRequest::Input** input) +{ + const auto& pr = original_inputs_.emplace( + std::piecewise_construct, std::forward_as_tuple(name), + std::forward_as_tuple(name, datatype, shape, dim_count)); + if (!pr.second) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input '" + name + "' already exists in request"); + } + + if (input != nullptr) { + *input = std::addressof(pr.first->second); + } + + needs_normalization_ = true; + return Status::Success; +} + +Status +InferenceRequest::AddOriginalInput( + const std::string& name, const inference::DataType datatype, + const std::vector& shape, InferenceRequest::Input** input) +{ + return AddOriginalInput(name, datatype, &shape[0], shape.size(), input); +} + +Status +InferenceRequest::AddRawInput( + const std::string& name, InferenceRequest::Input** input) +{ + if (original_inputs_.size() != 0) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "raw input '" + name + + "' can't be added to request with other inputs"); + } + const auto& pr = original_inputs_.emplace( + std::piecewise_construct, std::forward_as_tuple(name), + std::forward_as_tuple()); + if (!pr.second) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input '" + name + "' already exists in request"); + } + + if (input != nullptr) { + *input = std::addressof(pr.first->second); + } + + raw_input_name_ = name; + needs_normalization_ = true; + return Status::Success; +} + +Status +InferenceRequest::RemoveOriginalInput(const std::string& name) +{ + if (original_inputs_.erase(name) != 1) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input '" + name + "' does not exist in request"); + } + + if (name == raw_input_name_) { + raw_input_name_.clear(); + } + needs_normalization_ = true; + return Status::Success; +} + +Status +InferenceRequest::RemoveAllOriginalInputs() +{ + original_inputs_.clear(); + raw_input_name_.clear(); + needs_normalization_ = true; + return Status::Success; +} + +Status +InferenceRequest::AddOverrideInput( + const std::string& name, const inference::DataType datatype, + const int64_t batch_size, const std::vector& shape, + std::shared_ptr* input) +{ + std::shared_ptr i = std::make_shared(name, datatype, shape); + *(i->MutableShape()) = i->OriginalShape(); + if 
(batch_size > 0) { + *(i->MutableShapeWithBatchDim()) = {batch_size}; + i->MutableShapeWithBatchDim()->insert( + i->MutableShapeWithBatchDim()->end(), i->OriginalShape().begin(), + i->OriginalShape().end()); + } else { + *(i->MutableShapeWithBatchDim()) = i->OriginalShape(); + } + + RETURN_IF_ERROR(AddOverrideInput(i)); + if (input != nullptr) { + *input = std::move(i); + } + + return Status::Success; +} + +Status +InferenceRequest::AddOverrideInput( + const std::shared_ptr& input) +{ + LOG_VERBOSE(1) << LogRequest() << "adding input override for " + << input->Name() << ": " << *this; + + const auto& pr = + override_inputs_.emplace(std::make_pair(input->Name(), input)); + if (!pr.second) { + pr.first->second = input; + } + + // Add or replace this override in the inputs... + const auto res = inputs_.emplace(std::make_pair(input->Name(), input.get())); + if (!res.second) { + res.first->second = input.get(); + } + + LOG_VERBOSE(1) << LogRequest() << "added input override for " << input->Name() + << ": " << *this; + + return Status::Success; +} + +Status +InferenceRequest::AddOriginalRequestedOutput(const std::string& name) +{ + original_requested_outputs_.insert(name); + needs_normalization_ = true; + return Status::Success; +} + +Status +InferenceRequest::LoadInputStates() +{ + // Add the input states to the inference request. + if (sequence_states_ != nullptr) { + if (sequence_states_->IsNullRequest()) { + sequence_states_ = + SequenceStates::CopyAsNull(sequence_states_->NullSequenceStates()); + } + for (auto& input_state_pair : sequence_states_->InputStates()) { + auto& input_state = input_state_pair.second; + std::shared_ptr input = + std::make_shared( + input_state->Name(), input_state->DType(), input_state->Shape()); + *input->MutableShapeWithBatchDim() = input_state->Shape(); + input->SetData(input_state->Data()); + AddOverrideInput(input); + } + } + + return Status::Success; +} + +Status +InferenceRequest::RemoveOriginalRequestedOutput(const std::string& name) +{ + original_requested_outputs_.erase(name); + needs_normalization_ = true; + return Status::Success; +} + +Status +InferenceRequest::RemoveAllOriginalRequestedOutputs() +{ + original_requested_outputs_.clear(); + needs_normalization_ = true; + return Status::Success; +} + +Status +InferenceRequest::PrepareForInference() +{ + // Remove override inputs as those are added during any previous + // inference execution. + inputs_.clear(); + override_inputs_.clear(); + + // Renormalize if anything has changed in the inference request in a + // way that could impact renormalization. + if (needs_normalization_) { + RETURN_IF_ERROR(Normalize()); + needs_normalization_ = false; + } + + // Initially show the actual inputs to be only the original + // inputs. If overrides are added later they will be added to + // 'inputs_'. 
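+  // 'inputs_' holds non-owning pointers keyed by input name; if an override
+  // is added later, AddOverrideInput() replaces the corresponding entry with
+  // the override tensor.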
+ for (auto& pr : original_inputs_) { + inputs_.emplace( + std::make_pair(pr.second.Name(), std::addressof(pr.second))); + } + + // Clear the timestamps + queue_start_ns_ = 0; + batcher_start_ns_ = 0; +#ifdef TRITON_ENABLE_STATS + request_start_ns_ = 0; +#endif // TRITON_ENABLE_STATS + + LOG_VERBOSE(1) << LogRequest() << "prepared: " << *this; + + return Status::Success; +} + +Status +InferenceRequest::Normalize() +{ + const inference::ModelConfig& model_config = model_raw_->Config(); + + // Fill metadata for raw input + if (!raw_input_name_.empty()) { + const bool has_multiple_inputs = + (original_inputs_.size() != 1) || (model_config.input_size() != 1); + if (has_multiple_inputs) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "Raw request must only have 1 input (found " + + std::to_string(original_inputs_.size()) + + ") to be deduced but got " + + std::to_string(model_config.input_size()) + " inputs in '" + + ModelName() + "' model configuration"); + } + auto it = original_inputs_.begin(); + if (raw_input_name_ != it->first) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "Unexpected reference name for raw input '" + + raw_input_name_ + "' got '" + it->first + "'"); + } + const auto& config_input = model_config.input(0); + auto& raw_input = it->second; + std::vector shape; + if (model_config.max_batch_size() != 0) { + shape.emplace_back(1); + } + int64_t dynamic_axis = -1; + size_t element_cnt = 1; + for (const auto& dim : config_input.dims()) { + if (dim == triton::common::WILDCARD_DIM) { + if (dynamic_axis != -1) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "The shape of the raw input '" + + config_input.name() + + "' can not be deduced because there are more than one " + "variable-sized dimension"); + } + dynamic_axis = shape.size(); + } else { + element_cnt *= (size_t)dim; + } + shape.emplace_back(dim); + } + if ((config_input.data_type() == inference::DataType::TYPE_STRING)) { + const bool has_one_element = (dynamic_axis == -1) && (element_cnt == 1); + if (!has_one_element) { + return Status( + Status::Code::INVALID_ARG, LogRequest() + + "For BYTE datatype raw input, the " + "model must have input shape [1]"); + } + // In the case of BYTE data type, we will prepend the byte size to follow + // the Triton convention. + raw_input_size_ = raw_input.Data()->TotalByteSize(); + RETURN_IF_ERROR(raw_input.PrependData( + &raw_input_size_, sizeof(uint32_t), TRITONSERVER_MEMORY_CPU, 0)); + // Limit the BYTE raw input not to have host policy specific input for + // simplicity, such case won't happen given the current protocol spec. + // Will need to extend Input::PrependData() if needed. + if (!raw_input.HostPolicyData().empty()) { + return Status( + Status::Code::INVALID_ARG, LogRequest() + + "Raw input with data associated " + "with a host policy setting is not " + "currently supported"); + } + } else if (dynamic_axis != -1) { + shape[dynamic_axis] = + raw_input.Data()->TotalByteSize() / element_cnt / + triton::common::GetDataTypeByteSize(config_input.data_type()); + } + raw_input.SetMetadata(config_input.name(), config_input.data_type(), shape); + } + + // Initialize the requested outputs to be used during inference. If + // original_requested_outputs_ is empty assume all outputs specified + // in model config are being requested. 
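+  // For example (output names are hypothetical): a model config declaring
+  // outputs "OUT0" and "OUT1" with a request that names none yields both as
+  // requested, while a request naming an unknown output fails the
+  // GetOutput() lookup below.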
+ requested_outputs_.clear(); + if (original_requested_outputs_.size() == 0) { + for (const auto& output : model_config.output()) { + requested_outputs_.insert(output.name()); + } + } else { + // Validate if the original requested output name exists in the + // model configuration. + for (const auto& output_name : original_requested_outputs_) { + const inference::ModelOutput* output_config; + RETURN_IF_ERROR(model_raw_->GetOutput(output_name, &output_config)); + } + } + // Make sure that the request is providing the number of inputs + // as is expected by the model. + if ((original_inputs_.size() > (size_t)model_config.input_size()) || + (original_inputs_.size() < model_raw_->RequiredInputCount())) { + // If no input is marked as optional, then use exact match error message + // for consistency / backward compatibility + if ((size_t)model_config.input_size() == model_raw_->RequiredInputCount()) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "expected " + + std::to_string(model_config.input_size()) + " inputs but got " + + std::to_string(original_inputs_.size()) + " inputs for model '" + + ModelName() + "'"); + } else { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "expected number of inputs between " + + std::to_string(model_raw_->RequiredInputCount()) + " and " + + std::to_string(model_config.input_size()) + " but got " + + std::to_string(original_inputs_.size()) + " inputs for model '" + + ModelName() + "'"); + } + } + + // Determine the batch size and shape of each input. + if (model_config.max_batch_size() == 0) { + // Model does not support Triton-style batching so set as + // batch-size 0 and leave the tensor shapes as they are. + batch_size_ = 0; + for (auto& pr : original_inputs_) { + auto& input = pr.second; + *input.MutableShape() = input.OriginalShape(); + } + } else { + // Model does support Triton-style batching so each input tensor + // must have the same first dimension which is the batch + // size. Adjust the shape of the input tensors to remove the batch + // dimension. + batch_size_ = 0; + for (auto& pr : original_inputs_) { + auto& input = pr.second; + + // For a shape tensor, keep the tensor's shape as it is and mark + // that the input is a shape tensor. + const inference::ModelInput* input_config; + RETURN_IF_ERROR(model_raw_->GetInput(input.Name(), &input_config)); + if (input_config->is_shape_tensor()) { + *input.MutableShape() = input.OriginalShape(); + input.SetIsShapeTensor(true); + continue; + } + + if (input.OriginalShape().size() == 0) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input '" + input.Name() + + "' has no shape but model requires batch dimension for '" + + ModelName() + "'"); + } + + if (batch_size_ == 0) { + batch_size_ = input.OriginalShape()[0]; + } else if (input.OriginalShape()[0] != batch_size_) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input '" + input.Name() + + "' batch size does not match other inputs for '" + ModelName() + + "'"); + } + + input.MutableShape()->assign( + input.OriginalShape().begin() + 1, input.OriginalShape().end()); + } + } + + // Make sure request batch-size doesn't exceed what is supported by + // the model. 
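+  // Illustrative numbers (hypothetical model): with max_batch_size 8 and a
+  // request input of original shape [4, 3, 224, 224], batch_size_ was set
+  // to 4 above, the per-input Shape() became [3, 224, 224], and 4 <= 8 so
+  // the check below passes.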
+ if ((int)batch_size_ > model_config.max_batch_size()) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "inference request batch-size must be <= " + + std::to_string(model_config.max_batch_size()) + " for '" + + ModelName() + "'"); + } + + // Verify that each input shape is valid for the model, make + // adjustments for reshapes and find the total tensor size. + for (auto& pr : original_inputs_) { + const inference::ModelInput* input_config; + RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config)); + + auto& input = pr.second; + auto shape = input.MutableShape(); + + if (input.DType() != input_config->data_type()) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "inference input data-type is '" + + std::string( + triton::common::DataTypeToProtocolString(input.DType())) + + "', model expects '" + + std::string(triton::common::DataTypeToProtocolString( + input_config->data_type())) + + "' for '" + ModelName() + "'"); + } + + // Validate input shape + { + bool match_config = true; + const auto& config_dims = input_config->dims(); + const auto& input_dims = *shape; + if (config_dims.size() != (int64_t)input_dims.size()) { + match_config = false; + } else { + for (int i = 0; i < config_dims.size(); ++i) { + if (input_dims[i] == triton::common::WILDCARD_DIM) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + + "All input dimensions should be specified for input '" + + pr.first + "' for model '" + ModelName() + "', got " + + triton::common::DimsListToString(input.OriginalShape())); + } else if ( + (config_dims[i] != triton::common::WILDCARD_DIM) && + (config_dims[i] != input_dims[i])) { + match_config = false; + break; + } + } + } + + if (!match_config) { + triton::common::DimsList full_dims; + if (model_config.max_batch_size() > 0) { + full_dims.Add(triton::common::WILDCARD_DIM); + } + for (int i = 0; i < input_config->dims_size(); ++i) { + full_dims.Add(input_config->dims(i)); + } + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "unexpected shape for input '" + pr.first + + "' for model '" + ModelName() + "'. Expected " + + triton::common::DimsListToString(full_dims) + ", got " + + triton::common::DimsListToString(input.OriginalShape())); + } + } + + // If there is a reshape for this input then adjust them to + // match the reshape. As reshape may have variable-size + // dimensions, we need to record corresponding value so that we + // can set the value correctly for reshape. + if (input_config->has_reshape()) { + std::deque variable_size_values; + for (int64_t idx = 0; idx < input_config->dims_size(); idx++) { + if (input_config->dims(idx) == -1) { + variable_size_values.push_back((*shape)[idx]); + } + } + + shape->clear(); + for (const auto& dim : input_config->reshape().shape()) { + if (dim == -1) { + shape->push_back(variable_size_values.front()); + variable_size_values.pop_front(); + } else { + shape->push_back(dim); + } + } + } + + // Create shape with batch dimension. + // FIXME, should not need this!! 
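+    // Continuing the hypothetical example above: Shape() [3, 224, 224] with
+    // batch_size_ 4 gives ShapeWithBatchDim() [4, 3, 224, 224]; with
+    // batch_size_ 0 the two shapes are identical.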
+ if (batch_size_ == 0) { + *input.MutableShapeWithBatchDim() = *shape; + } else { + input.MutableShapeWithBatchDim()->clear(); + input.MutableShapeWithBatchDim()->push_back(batch_size_); + for (int64_t d : *shape) { + input.MutableShapeWithBatchDim()->push_back(d); + } + } + } + + return Status::Success; +} + +#ifdef TRITON_ENABLE_STATS +void +InferenceRequest::ReportStatistics( + MetricModelReporter* metric_reporter, bool success, + const uint64_t compute_start_ns, const uint64_t compute_input_end_ns, + const uint64_t compute_output_start_ns, const uint64_t compute_end_ns) +{ + if (!collect_stats_) { + return; + } + +#ifdef TRITON_ENABLE_TRACING + if (trace_ != nullptr) { + trace_->Report(TRITONSERVER_TRACE_COMPUTE_START, compute_start_ns); + trace_->Report(TRITONSERVER_TRACE_COMPUTE_INPUT_END, compute_input_end_ns); + trace_->Report( + TRITONSERVER_TRACE_COMPUTE_OUTPUT_START, compute_output_start_ns); + trace_->Report(TRITONSERVER_TRACE_COMPUTE_END, compute_end_ns); + } +#endif // TRITON_ENABLE_TRACING + + INFER_STATS_DECL_TIMESTAMP(request_end_ns); + + if (success) { + model_raw_->MutableStatsAggregator()->UpdateSuccess( + metric_reporter, std::max(1U, batch_size_), request_start_ns_, + queue_start_ns_, compute_start_ns, compute_input_end_ns, + compute_output_start_ns, compute_end_ns, request_end_ns); + if (secondary_stats_aggregator_ != nullptr) { + secondary_stats_aggregator_->UpdateSuccess( + nullptr /* metric_reporter */, std::max(1U, batch_size_), + request_start_ns_, queue_start_ns_, compute_start_ns, + compute_input_end_ns, compute_output_start_ns, compute_end_ns, + request_end_ns); + } + } else { + model_raw_->MutableStatsAggregator()->UpdateFailure( + metric_reporter, request_start_ns_, request_end_ns); + if (secondary_stats_aggregator_ != nullptr) { + secondary_stats_aggregator_->UpdateFailure( + nullptr /* metric_reporter */, request_start_ns_, request_end_ns); + } + } +} + +void +InferenceRequest::ReportStatisticsWithDuration( + MetricModelReporter* metric_reporter, bool success, + const uint64_t compute_start_ns, const uint64_t compute_input_duration_ns, + const uint64_t compute_infer_duration_ns, + const uint64_t compute_output_duration_ns) +{ + if (!collect_stats_) { + return; + } + + INFER_STATS_DECL_TIMESTAMP(request_end_ns); + + if (success) { + model_raw_->MutableStatsAggregator()->UpdateSuccessWithDuration( + metric_reporter, std::max(1U, batch_size_), request_start_ns_, + queue_start_ns_, compute_start_ns, request_end_ns, + compute_input_duration_ns, compute_infer_duration_ns, + compute_output_duration_ns); + if (secondary_stats_aggregator_ != nullptr) { + secondary_stats_aggregator_->UpdateSuccessWithDuration( + nullptr /* metric_reporter */, std::max(1U, batch_size_), + request_start_ns_, queue_start_ns_, compute_start_ns, request_end_ns, + compute_input_duration_ns, compute_infer_duration_ns, + compute_output_duration_ns); + } + } else { + model_raw_->MutableStatsAggregator()->UpdateFailure( + metric_reporter, request_start_ns_, request_end_ns); + if (secondary_stats_aggregator_ != nullptr) { + secondary_stats_aggregator_->UpdateFailure( + nullptr /* metric_reporter */, request_start_ns_, request_end_ns); + } + } +} + +void +InferenceRequest::ReportStatisticsCacheHit(MetricModelReporter* metric_reporter) +{ + // Capture end of request time + INFER_STATS_DECL_TIMESTAMP(request_end_ns); + + if (cache_lookup_start_ns_ >= cache_lookup_end_ns_) { + LOG_WARNING << LogRequest() + << "Cache lookup timestamps were not set correctly. 
Cache " + "lookup duration stats may be incorrect."; + } + const uint64_t cache_lookup_duration_ns = + cache_lookup_end_ns_ - cache_lookup_start_ns_; + + // Cache hit is always success + model_raw_->MutableStatsAggregator()->UpdateSuccessCacheHit( + metric_reporter, std::max(1U, batch_size_), request_start_ns_, + queue_start_ns_, cache_lookup_start_ns_, request_end_ns, + cache_lookup_duration_ns); + if (secondary_stats_aggregator_ != nullptr) { + secondary_stats_aggregator_->UpdateSuccessCacheHit( + nullptr /* metric_reporter */, std::max(1U, batch_size_), + request_start_ns_, queue_start_ns_, cache_lookup_start_ns_, + request_end_ns, cache_lookup_duration_ns); + } +} + +void +InferenceRequest::ReportStatisticsCacheMiss( + MetricModelReporter* metric_reporter) +{ + if (cache_lookup_start_ns_ >= cache_lookup_end_ns_) { + LOG_WARNING << LogRequest() + << "Cache lookup timestamps were not set correctly. Cache " + "lookup duration stats may be incorrect."; + } + if (cache_insertion_start_ns_ >= cache_insertion_end_ns_) { + LOG_WARNING << LogRequest() + << "Cache insertion timestamps were not set correctly. Cache " + "insertion duration stats may be incorrect."; + } + + const uint64_t cache_lookup_duration_ns = + cache_lookup_end_ns_ - cache_lookup_start_ns_; + + const uint64_t cache_insertion_duration_ns = + cache_insertion_end_ns_ - cache_insertion_start_ns_; + + model_raw_->MutableStatsAggregator()->UpdateSuccessCacheMiss( + metric_reporter, cache_lookup_duration_ns, cache_insertion_duration_ns); + if (secondary_stats_aggregator_ != nullptr) { + secondary_stats_aggregator_->UpdateSuccessCacheMiss( + nullptr /* metric_reporter */, cache_lookup_duration_ns, + cache_insertion_duration_ns); + } +} +#endif // TRITON_ENABLE_STATS + +// +// Input +// +InferenceRequest::Input::Input() + : is_shape_tensor_(false), data_(new MemoryReference), + has_host_policy_specific_data_(false) +{ +} + +InferenceRequest::Input::Input( + const std::string& name, const inference::DataType datatype, + const int64_t* shape, const uint64_t dim_count) + : name_(name), datatype_(datatype), + original_shape_(shape, shape + dim_count), is_shape_tensor_(false), + data_(new MemoryReference), has_host_policy_specific_data_(false) +{ +} + +InferenceRequest::Input::Input( + const std::string& name, const inference::DataType datatype, + const std::vector& shape) + : name_(name), datatype_(datatype), original_shape_(shape), + is_shape_tensor_(false), data_(new MemoryReference), + has_host_policy_specific_data_(false) +{ +} + +void +InferenceRequest::Input::SetMetadata( + const std::string& name, const inference::DataType& dt, + const std::vector& shape) +{ + name_ = name; + datatype_ = dt; + original_shape_ = shape; +} + +Status +InferenceRequest::Input::SetIsShapeTensor(const bool is_shape_tensor) +{ + is_shape_tensor_ = is_shape_tensor; + return Status::Success; +} + +const std::shared_ptr& +InferenceRequest::Input::Data(const std::string& host_policy_name) const +{ + auto device_data = host_policy_data_map_.find(host_policy_name); + if (device_data == host_policy_data_map_.end()) { + // Fall back on default data if there is no data that has been added for + // this host policy + return data_; + } + return device_data->second; +} + +Status +InferenceRequest::Input::AppendData( + const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) +{ + if (byte_size > 0) { + std::static_pointer_cast(data_)->AddBuffer( + static_cast(base), byte_size, memory_type, memory_type_id); + } + + return 
Status::Success; +} + +Status +InferenceRequest::Input::AppendDataWithBufferAttributes( + const void* base, BufferAttributes* buffer_attributes) +{ + if (buffer_attributes->ByteSize() > 0) { + std::static_pointer_cast(data_)->AddBuffer( + static_cast(base), buffer_attributes); + } + return Status::Success; +} + +Status +InferenceRequest::Input::AppendDataWithHostPolicy( + const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id, const char* host_policy_name) +{ + auto device_data = host_policy_data_map_.find(host_policy_name); + has_host_policy_specific_data_ = true; + if (device_data == host_policy_data_map_.end()) { + auto insert_pair = host_policy_data_map_.insert( + std::make_pair(std::string(host_policy_name), new MemoryReference)); + device_data = insert_pair.first; + } + if (byte_size > 0) { + std::static_pointer_cast(device_data->second) + ->AddBuffer( + static_cast(base), byte_size, memory_type, + memory_type_id); + } + + return Status::Success; +} + +Status +InferenceRequest::Input::PrependData( + const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) +{ + if (byte_size > 0) { + std::static_pointer_cast(data_)->AddBufferFront( + static_cast(base), byte_size, memory_type, memory_type_id); + } + + return Status::Success; +} + +Status +InferenceRequest::Input::SetData(const std::shared_ptr& data) +{ + if (data_->TotalByteSize() != 0) { + return Status( + Status::Code::INVALID_ARG, + "input '" + name_ + "' already has data, can't overwrite"); + } + + data_ = data; + + return Status::Success; +} + +Status +InferenceRequest::Input::SetData( + const std::string& host_policy_name, const std::shared_ptr& data) +{ + if (host_policy_data_map_.find(host_policy_name) != + host_policy_data_map_.end()) { + return Status( + Status::Code::INVALID_ARG, "input '" + name_ + + "' already has data for host policy '" + + host_policy_name + "', can't overwrite"); + } + + host_policy_data_map_.emplace(host_policy_name, data); + + return Status::Success; +} + +Status +InferenceRequest::Input::RemoveAllData() +{ + data_ = std::make_shared(); + host_policy_data_map_.clear(); + has_host_policy_specific_data_ = false; + return Status::Success; +} + +Status +InferenceRequest::Input::DataBuffer( + const size_t idx, const void** base, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) const +{ + *base = data_->BufferAt(idx, byte_size, memory_type, memory_type_id); + + return Status::Success; +} + +Status +InferenceRequest::Input::DataBufferAttributes( + const size_t idx, const void** base, + BufferAttributes** buffer_attributes) const +{ + *base = data_->BufferAt(idx, buffer_attributes); + + return Status::Success; +} + +Status +InferenceRequest::Input::DataBufferForHostPolicy( + const size_t idx, const void** base, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id, + const std::string& host_policy_name) const +{ + auto device_data = host_policy_data_map_.find(host_policy_name); + if (device_data == host_policy_data_map_.end()) { + // Return data buffer if there is no host-policy specific buffer available + *base = data_->BufferAt(idx, byte_size, memory_type, memory_type_id); + } else { + *base = device_data->second->BufferAt( + idx, byte_size, memory_type, memory_type_id); + } + + return Status::Success; +} + +size_t +InferenceRequest::Input::DataBufferCountForHostPolicy( + const std::string& host_policy_name) const +{ + auto policy_data = 
host_policy_data_map_.find(host_policy_name); + if (policy_data != host_policy_data_map_.end()) { + return policy_data->second->BufferCount(); + } + return data_->BufferCount(); +} + +InferenceRequest::SequenceId::SequenceId() + : sequence_label_(""), sequence_index_(0), + id_type_(InferenceRequest::SequenceId::DataType::UINT64) +{ +} + +InferenceRequest::SequenceId::SequenceId(const std::string& sequence_label) + : sequence_label_(sequence_label), sequence_index_(0), + id_type_(InferenceRequest::SequenceId::DataType::STRING) +{ +} + +InferenceRequest::SequenceId::SequenceId(uint64_t sequence_index) + : sequence_label_(""), sequence_index_(sequence_index), + id_type_(InferenceRequest::SequenceId::DataType::UINT64) +{ +} + +InferenceRequest::SequenceId& +InferenceRequest::SequenceId::operator=(const std::string& rhs) +{ + sequence_label_ = rhs; + sequence_index_ = 0; + id_type_ = InferenceRequest::SequenceId::DataType::STRING; + return *this; +} + +InferenceRequest::SequenceId& +InferenceRequest::SequenceId::operator=(const uint64_t rhs) +{ + sequence_label_ = ""; + sequence_index_ = rhs; + id_type_ = InferenceRequest::SequenceId::DataType::UINT64; + return *this; +} + +std::ostream& +operator<<(std::ostream& out, const InferenceRequest& request) +{ + out << "[0x" << std::addressof(request) << "] " + << "request id: " << request.Id() << ", model: " << request.ModelName() + << ", requested version: " << request.RequestedModelVersion() + << ", actual version: " << request.ActualModelVersion() << ", flags: 0x" + << std::hex << request.Flags() << std::dec + << ", correlation id: " << request.CorrelationId() + << ", batch size: " << request.BatchSize() + << ", priority: " << request.Priority() + << ", timeout (us): " << request.TimeoutMicroseconds() << std::endl; + + out << "original inputs:" << std::endl; + for (const auto& itr : request.OriginalInputs()) { + out << "[0x" << std::addressof(itr.second) << "] " << itr.second + << std::endl; + } + + out << "override inputs:" << std::endl; + for (const auto& itr : request.OverrideInputs()) { + out << "[0x" << itr.second.get() << "] " << *itr.second << std::endl; + } + + out << "inputs:" << std::endl; + for (const auto& itr : request.ImmutableInputs()) { + out << "[0x" << itr.second << "] " << *itr.second << std::endl; + } + + out << "original requested outputs:" << std::endl; + for (const auto& name : request.OriginalRequestedOutputs()) { + out << name << std::endl; + } + + out << "requested outputs:" << std::endl; + for (const auto& name : request.ImmutableRequestedOutputs()) { + out << name << std::endl; + } + + return out; +} + +std::ostream& +operator<<(std::ostream& out, const InferenceRequest::Input& input) +{ + out << "input: " << input.Name() + << ", type: " << triton::common::DataTypeToProtocolString(input.DType()) + << ", original shape: " + << triton::common::DimsListToString(input.OriginalShape()) + << ", batch + shape: " + << triton::common::DimsListToString(input.ShapeWithBatchDim()) + << ", shape: " << triton::common::DimsListToString(input.Shape()); + if (input.IsShapeTensor()) { + out << ", is_shape_tensor: True"; + } + return out; +} + +std::ostream& +operator<<(std::ostream& out, const InferenceRequest::SequenceId& sequence_id) +{ + switch (sequence_id.Type()) { + case InferenceRequest::SequenceId::DataType::STRING: + out << sequence_id.StringValue(); + break; + case InferenceRequest::SequenceId::DataType::UINT64: + out << sequence_id.UnsignedIntValue(); + break; + default: + out << sequence_id.UnsignedIntValue(); + break; + } + 
return out; +} + +bool +operator==( + const InferenceRequest::SequenceId lhs, + const InferenceRequest::SequenceId rhs) +{ + if (lhs.Type() == rhs.Type()) { + switch (lhs.Type()) { + case InferenceRequest::SequenceId::DataType::STRING: + return lhs.StringValue() == rhs.StringValue(); + case InferenceRequest::SequenceId::DataType::UINT64: + return lhs.UnsignedIntValue() == rhs.UnsignedIntValue(); + default: + return lhs.UnsignedIntValue() == rhs.UnsignedIntValue(); + } + } else { + return false; + } +} + +bool +operator!=( + const InferenceRequest::SequenceId lhs, + const InferenceRequest::SequenceId rhs) +{ + return !(lhs == rhs); +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/infer_request.h b/3rdparty/core-r22.12/src/infer_request.h new file mode 100644 index 0000000000000000000000000000000000000000..0dba6aa45ab255bf39c5fbbfc5a0d0d5ac3ea659 --- /dev/null +++ b/3rdparty/core-r22.12/src/infer_request.h @@ -0,0 +1,800 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include "buffer_attributes.h" +#include "infer_response.h" +#include "infer_stats.h" +#include "infer_trace.h" +#include "memory.h" +#include "response_allocator.h" +#include "sequence_state.h" +#include "status.h" +#include "triton/common/model_config.h" +#include "tritonserver_apis.h" + +namespace triton { namespace core { + +class Model; +class InferenceServer; +class MetricModelReporter; + +// +// An inference request. A request can be used multiple times for +// inference but before each inference run, PrepareForInference() must +// be called to verify and prepare the request. Verification involves +// ensuring that any changes made since the last inference are +// valid. Preparing involves removing/resetting any state left over +// from the previous inference. 
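+//
+// A rough usage sketch pieced together from the comments below (the model,
+// the data buffer, the callbacks and the "INPUT0"/"OUTPUT0" names are
+// placeholders, not part of this header):
+//
+//   std::unique_ptr<InferenceRequest> req(
+//       new InferenceRequest(model, 1 /* requested model version */));
+//   InferenceRequest::Input* input = nullptr;
+//   req->AddOriginalInput(
+//       "INPUT0", inference::DataType::TYPE_FP32,
+//       std::vector<int64_t>{1, 3}, &input);
+//   input->AppendData(data_ptr, 12 /* bytes */, TRITONSERVER_MEMORY_CPU, 0);
+//   req->AddOriginalRequestedOutput("OUTPUT0");
+//   req->SetResponseCallback(allocator, alloc_userp, response_fn, userp);
+//   req->SetReleaseCallback(release_fn, release_userp);
+//   req->PrepareForInference();
+//   InferenceRequest::Run(req);  // on success, ownership moves to the model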
+// +class InferenceRequest { + public: + // Input tensor + class Input { + public: + Input(); + Input( + const std::string& name, const inference::DataType datatype, + const std::vector& shape); + Input( + const std::string& name, const inference::DataType datatype, + const int64_t* shape, const uint64_t dim_count); + + // Set the name, data type and original shape of the input tensor. + void SetMetadata( + const std::string& name, const inference::DataType& dt, + const std::vector& shape); + + // The name of the input tensor. There is no mutable operator for + // the name because it is used in a InferenceRequest map and a + // mutable method would allow it to get out-of-sync. + const std::string& Name() const { return name_; } + + // Data type of the input tensor. + inference::DataType DType() const { return datatype_; } + + // The original shape of the input tensor. + const std::vector& OriginalShape() const + { + return original_shape_; + } + + // The shape of the input tensor after normalization. This shape + // is the original shape modified as required/expected by + // inference processing. + const std::vector& Shape() const { return shape_; } + std::vector* MutableShape() { return &shape_; } + + // FIXME. Should not need these functions. All shapes kept here + // should include the batch dimension instead of breaking the same + // into batch + shape. + const std::vector& ShapeWithBatchDim() const + { + return shape_with_batch_dim_; + } + std::vector* MutableShapeWithBatchDim() + { + return &shape_with_batch_dim_; + } + + // Return true if host-specific data was added for this input + bool HasHostPolicySpecificData() const + { + return has_host_policy_specific_data_; + } + + // Whether or not the input is a tensorrt shape tensor + bool IsShapeTensor() const { return is_shape_tensor_; } + + // Set the input to be treated as a shape tensor. + Status SetIsShapeTensor(const bool is_shape_tensor); + + // The data for this input. + const std::shared_ptr& Data() const { return data_; } + + // The data for this input for a specific device + const std::shared_ptr& Data( + const std::string& host_policy_name) const; + + // Return all host policy data set for this input + const std::map>& HostPolicyData() const + { + return host_policy_data_map_; + } + + // Set the data for this input. Error if input already has some + // data. + Status SetData(const std::shared_ptr& data); + + // Set the data associated with the host policy for this input. + // Return error if input already has some data. + Status SetData( + const std::string& host_policy_name, + const std::shared_ptr& data); + + // Append a new buffer of data to this input. + Status AppendData( + const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id); + + Status AppendDataWithHostPolicy( + const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id, const char* host_policy_name); + + Status AppendDataWithBufferAttributes( + const void* base, BufferAttributes* buffer_attributes); + + // Prepend a new buffer of data to this input. + Status PrependData( + const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id); + + // Remove all existing data for the input. + Status RemoveAllData(); + + // Get the number of buffers containing the input tensor data. + size_t DataBufferCount() const { return data_->BufferCount(); } + + // Get the number of buffers containing the input tensor data with + // host policy. 
If there are no buffers corresponding to the specific + // host policy, the number of buffers in the fallback input data is + // returned. + size_t DataBufferCountForHostPolicy( + const std::string& host_policy_name) const; + + // Get the 'idx' buffer containing a contiguous chunk of bytes for + // the input. Return error is 'idx' refers to a buffer that does + // not exist. Return a pointer to the chunk in 'base' and the + // size of the chunk in 'byte_size'. 'memory_type' acts as + // both input and output. On input 'memory_type' is the buffer + // memory type preferred by the function caller. On return + // 'memory_type' gives the actual memory type of the chunk pointed + // to by 'base'. 'memory_type_id' acts as both input and + // output. On input 'memory_type_id' is the buffer memory type id + // preferred by the function caller. On return 'memory_type_id' + // gives the actual memory type id of the chunk pointed to by + // 'base'. + Status DataBuffer( + const size_t idx, const void** base, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) const; + + // Get the buffer attributes associated with 'idx' buffer. + Status DataBufferAttributes( + const size_t idx, const void** base, + BufferAttributes** buffer_attributes) const; + + // Get the 'idx' buffer containing a contiguous chunk of bytes for + // the input. Return error is 'idx' refers to a buffer that does + // not exist. Return a pointer to the chunk in 'base' and the + // size of the chunk in 'byte_size'. 'memory_type' acts as + // both input and output. On input 'memory_type' is the buffer + // memory type preferred by the function caller. On return + // 'memory_type' gives the actual memory type of the chunk pointed + // to by 'base'. 'memory_type_id' acts as both input and + // output. On input 'memory_type_id' is the buffer memory type id + // preferred by the function caller. On return 'memory_type_id' + // gives the actual memory type id of the chunk pointed to by + // 'base'. + Status DataBufferForHostPolicy( + const size_t idx, const void** base, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id, + const std::string& host_policy_name) const; + + private: + DISALLOW_COPY_AND_ASSIGN(Input); + friend std::ostream& operator<<( + std::ostream& out, const InferenceRequest::Input& input); + + std::string name_; + inference::DataType datatype_; + std::vector original_shape_; + std::vector shape_; + std::vector shape_with_batch_dim_; + bool is_shape_tensor_; + std::shared_ptr data_; + + bool has_host_policy_specific_data_; + // A map of host policy to input data memory + std::map> host_policy_data_map_; + }; + + // Sequence ID can be either a 64 bit integer or a string. 
+ // This class implements the SequenceId type + class SequenceId { + public: + enum class DataType { UINT64, STRING }; + + SequenceId(); + SequenceId(const std::string& sequence_label); + SequenceId(uint64_t sequence_index); + SequenceId& operator=(const SequenceId& rhs) = default; + SequenceId& operator=(const std::string& rhs); + SequenceId& operator=(const uint64_t rhs); + + // Functions that help determine exact type of sequence Id + DataType Type() const { return id_type_; } + bool InSequence() const + { + return ((sequence_label_ != "") || (sequence_index_ != 0)); + } + + // Get the value of the SequenceId based on the type + const std::string& StringValue() const { return sequence_label_; } + uint64_t UnsignedIntValue() const { return sequence_index_; } + + private: + friend std::ostream& operator<<( + std::ostream& out, const InferenceRequest::SequenceId& correlation_id); + friend bool operator==(const SequenceId lhs, const SequenceId rhs); + friend bool operator!=(const SequenceId lhs, const SequenceId rhs); + + std::string sequence_label_; + uint64_t sequence_index_; + DataType id_type_; + }; + + // InferenceRequest + // + // The two constructors are identical except one takes model as a + // shared pointer and the other as a raw pointer. The shared pointer + // version is the primary one and acts to keep the model alive as + // long as the request is in flight. The raw pointer version is used + // only for cases where the model itself is issuing a request + // (e.g. warmup) and no shared pointer version of the model exists + // (because we aren't using shared_from_this). + InferenceRequest( + const std::shared_ptr& model, + const int64_t requested_model_version); + + InferenceRequest(Model* model, const int64_t requested_model_version); + + const std::string& ModelName() const; + int64_t RequestedModelVersion() const { return requested_model_version_; } + int64_t ActualModelVersion() const; + + const std::string& Id() const { return id_; } + void SetId(const std::string& i) { id_ = i; } + // Return string for logging request ID + std::string LogRequest() const + { + std::string id = Id(); + if (id.empty()) { + id = ""; + } + return std::string("[request id: ") + id + "] "; + } + + // Flags for the request, union of TRITONSERVER_RequestFlag. + uint32_t Flags() const { return flags_; } + void SetFlags(uint32_t f) { flags_ = f; } + + const SequenceId& CorrelationId() const { return correlation_id_; } + void SetCorrelationId(const SequenceId& c) { correlation_id_ = c; } + + // The batch size of the request, as understood by Triton. A + // batch-size of 0 indicates that the model doesn't support batching + // in a way that Triton understands. Batch size is not set + // explicitly so there is no setter for it. It is set when the + // request is normalized. + uint32_t BatchSize() const { return batch_size_; } + + uint32_t Priority() const { return priority_; } + void SetPriority(uint32_t p); + + uint64_t TimeoutMicroseconds() const { return timeout_us_; } + void SetTimeoutMicroseconds(uint64_t t) { timeout_us_ = t; } + + uint64_t CacheKey() const { return cache_key_; } + // It is up to the user to update the cache_key_ if modifying any hashable + // fields of the request after cache_key_is_set_ has been set to true. 
+ void SetCacheKey(uint64_t key) + { + cache_key_ = key; + cache_key_is_set_ = true; + } + bool CacheKeyIsSet() const { return cache_key_is_set_; } + +#ifdef TRITON_ENABLE_TRACING + const std::shared_ptr& Trace() const { return trace_; } + std::shared_ptr* MutableTrace() { return &trace_; } + void SetTrace(const std::shared_ptr& trace) + { + trace_ = trace; + response_factory_->SetTrace(trace); + } + void ReleaseTrace() + { + trace_ = nullptr; + response_factory_->ReleaseTrace(); + } + + Status TraceInputTensors( + TRITONSERVER_InferenceTraceActivity activity, const std::string& msg); +#endif // TRITON_ENABLE_TRACING + + // The original inputs are the inputs added to the request before + // the inference execution (that is before + // TRITONSERVER_ServerInferAsync is called). Once execution has + // started the original inputs should not be modified until + // execution completes (and those modifications will apply to the + // next inference execution). + Status MutableOriginalInput(const std::string& name, Input** input); + std::unordered_map* MutableOriginalInputs() + { + return &original_inputs_; + } + const std::unordered_map& OriginalInputs() const + { + return original_inputs_; + } + + // The override inputs are the inputs added to the request after + // inference execution has started (that is after + // TRITONSERVER_ServerInferAsync or equivalent is called). During + // inference processing, if Triton needs to change an original input + // it will add an override instead of changing the original. Triton + // will also use an override if it needs to add a new input to the + // request. Overrides are recorded as shared_ptr so that the same + // override can be used efficiently multiple times or even in + // multiple requests simultaneously. Must be careful not to modify + // an override input if it is being shared unless you want that + // change to be reflected in all requests that hold that override + // input. Override inputs within a specific request are not + // persisted across inference calls. + std::unordered_map>* + MutableOverrideInputs() + { + return &override_inputs_; + } + const std::unordered_map>& + OverrideInputs() const + { + return override_inputs_; + } + + // Get an input taking into account both original inputs and + // overrides. If an override input is available use it, otherwise + // use the original input. Accessing inputs via this method is not + // valid until after PrepareForInference is called. + Status ImmutableInput(const std::string& name, const Input** input) const; + const std::unordered_map& ImmutableInputs() const + { + return inputs_; + } + + // The original requested outputs are the requested outputs added to + // the request before the inference execution (that is before + // TRITONSERVER_ServerInferAsync is called). Once execution has + // started the original requested outputs should not be modified + // until execution completes (and those modifications will apply to + // the next inference execution). + const std::set& OriginalRequestedOutputs() const + { + return original_requested_outputs_; + } + + // Get the requested outputs that should be used during + // inference. Accessing outputs via this method is not valid until + // after PrepareForInference is called. + const std::set& ImmutableRequestedOutputs() const + { + return (requested_outputs_.empty()) ? original_requested_outputs_ + : requested_outputs_; + } + + // Get the response factory. 
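+  // (OutputBufferProperties() reaches the response allocator through this
+  // factory.)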
+  const std::shared_ptr<InferenceResponseFactory>& ResponseFactory() const
+  {
+    return response_factory_;
+  }
+
+  // Add an original input to the request. If 'input' is non-null
+  // return a pointer to the newly added input.
+  Status AddOriginalInput(
+      const std::string& name, const inference::DataType datatype,
+      const int64_t* shape, const uint64_t dim_count, Input** input = nullptr);
+  Status AddOriginalInput(
+      const std::string& name, const inference::DataType datatype,
+      const std::vector<int64_t>& shape, Input** input = nullptr);
+
+  // Add an original raw input to the request. If 'input' is non-null
+  // return a pointer to the newly added input.
+  Status AddRawInput(const std::string& name, Input** input = nullptr);
+
+  // Remove a single original input or all inputs.
+  Status RemoveOriginalInput(const std::string& name);
+  Status RemoveAllOriginalInputs();
+
+  // Add an override input to the request. If 'input' is non-null
+  // return a pointer to the newly added input.
+  // FIXME passing batch size is special handling for backend API.
+  // For override input, the 'shape' is without batch dimension for
+  // backends implemented w/o the backend API (which need correct
+  // input.Shape()), but backend API uses input.ShapeWithBatchDim().
+  Status AddOverrideInput(
+      const std::string& name, const inference::DataType datatype,
+      const int64_t batch_size, const std::vector<int64_t>& shape,
+      std::shared_ptr<Input>* input = nullptr);
+
+  // Add an override input to the request.
+  Status AddOverrideInput(const std::shared_ptr<Input>& input);
+
+  // Request an original requested output.
+  Status AddOriginalRequestedOutput(const std::string& name);
+
+  // Remove a single original requested output or all requested
+  // outputs.
+  Status RemoveOriginalRequestedOutput(const std::string& name);
+  Status RemoveAllOriginalRequestedOutputs();
+
+  // Initialize the release callback for the request.
+  Status SetReleaseCallback(
+      TRITONSERVER_InferenceRequestReleaseFn_t release_fn, void* release_userp)
+  {
+    release_fn_ = release_fn;
+    release_userp_ = release_userp;
+    return Status::Success;
+  }
+
+  // Initialize the response factory that is to be used with any
+  // responses produced for this request.
+  Status SetResponseCallback(
+      const ResponseAllocator* allocator, void* alloc_userp,
+      TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
+      void* response_userp)
+  {
+    response_factory_.reset(new InferenceResponseFactory(
+        model_shared_, id_, allocator, alloc_userp, response_fn, response_userp,
+        response_delegator_));
+    return Status::Success;
+  }
+
+  // Returns the preferred memory type and memory type ID of the output buffer
+  // for the request. 'name' and 'byte_size' are optional and set to nullptr
+  // if not specified; if provided, they give the allocator more information.
+  // 'memory_type' and 'memory_type_id' are also used as input to provide types
+  // preferred by the caller.
+  // Status::Code::UNAVAILABLE will be returned if output properties are not
+  // available.
+  Status OutputBufferProperties(
+      const char* name, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
+      int64_t* memory_type_id);
+
+  // Add a callback to be invoked on releasing the request object from Triton.
+  // Multiple callbacks can be added by calling this function in order,
+  // and they will be invoked in reverse order.
+  Status AddInternalReleaseCallback(std::function<void()>&& callback)
+  {
+    release_callbacks_.emplace_back(std::move(callback));
+    return Status::Success;
+  }
+
+  // Add a delegator to be invoked on sending the responses of this request.
+ // The response will be passed to 'delegator' and 'delegator' must call the + // InferenceResponse::Send() to send the response. + Status SetResponseDelegator( + std::function&&, const uint32_t)>&& delegator) + { + response_delegator_ = std::move(delegator); + return response_factory_->SetResponseDelegator(response_delegator_); + } + + Status SetSequenceStates( + const std::shared_ptr& sequence_states) + { + sequence_states_ = sequence_states; + return Status::Success; + } + + Status LoadInputStates(); + + const std::shared_ptr& GetSequenceStates() const + { + return sequence_states_; + } + + // Prepare this request for inference. + Status PrepareForInference(); + + // Run this inference request using the model associated with the + // request. If Status::Success is returned then the call has taken + // ownership of the request object and so 'request' will be + // nullptr. If non-success is returned then the caller still retains + // ownership of 'request'. + static Status Run(std::unique_ptr& request); + + // Send an error response for this request. If 'status' is Success + // then no response is sent and the request is not released (even if + // 'release_request' is true). Because this is sending an error it + // is assumed that this is the last response for the request and so + // the FINAL flag is set in the response callback. If + // 'release_request' is true then the release callback is called for + // this request and ownership is given to the callback. Thus, if + // 'release_request' is true 'request' is returned as nullptr. + static void RespondIfError( + std::unique_ptr& request, const Status& status, + const bool release_request = false); + + // Send an error response to a set of 'requests'. If 'status' is + // Success then no responses are sent and the requests are not + // released (even if 'release_request' is true). Because this is + // sending an error it is assumed that this is the last response for + // the requests and so the FINAL flag is set in the response + // callbacks. If 'release_request' is true then the release callback + // is called for each request, and the request ownership is given to + // the callback. Thus, if 'release_request' is true 'requests' is + // returned with all nullptrs. + static void RespondIfError( + std::vector>& requests, + const Status& status, const bool release_requests = false); + + // Release the request. Call the release callback and transfer + // ownership of the request to the callback. On return 'request' is + // nullptr. + static void Release( + std::unique_ptr&& request, + const uint32_t release_flags); + + // Create a copy of 'from' suitable for use as a "null" request as + // required for the direct sequence batcher. The returned copy will + // contain only the minimum content required for a null request. + // The statistics of the copy will not be collected. 
+ static InferenceRequest* CopyAsNull(const InferenceRequest& from); + + uint64_t QueueStartNs() const { return queue_start_ns_; } + uint64_t CaptureQueueStartNs() + { + queue_start_ns_ = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + return queue_start_ns_; + } + + uint64_t CacheLookupStartNs() const { return cache_lookup_start_ns_; } + uint64_t CaptureCacheLookupStartNs() + { + cache_lookup_start_ns_ = + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + return cache_lookup_start_ns_; + } + + uint64_t CacheLookupEndNs() const { return cache_lookup_end_ns_; } + uint64_t CaptureCacheLookupEndNs() + { + cache_lookup_end_ns_ = + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + return cache_lookup_end_ns_; + } + + uint64_t CacheInsertionStartNs() const { return cache_insertion_start_ns_; } + uint64_t CaptureCacheInsertionStartNs() + { + cache_insertion_start_ns_ = + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + return cache_insertion_start_ns_; + } + + uint64_t CacheInsertionEndNs() const { return cache_insertion_end_ns_; } + uint64_t CaptureCacheInsertionEndNs() + { + cache_insertion_end_ns_ = + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + return cache_insertion_end_ns_; + } + + uint64_t BatcherStartNs() const { return batcher_start_ns_; } + uint64_t CaptureBatcherStartNs() + { + batcher_start_ns_ = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + return batcher_start_ns_; + } + +#ifdef TRITON_ENABLE_STATS + uint64_t RequestStartNs() const { return request_start_ns_; } + uint64_t CaptureRequestStartNs() + { + request_start_ns_ = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + return request_start_ns_; + } + + // Report the statistics to stats collectors associated with the request. + // Duration and timestamps provide two granularities for stats collectors. + void ReportStatistics( + MetricModelReporter* metric_reporter, bool success, + const uint64_t compute_start_ns, const uint64_t compute_input_end_ns, + const uint64_t compute_output_start_ns, const uint64_t compute_end_ns); + + // Report the statistics to stats collectors associated with the request. + // Duration and timestamps provide two granularities for stats collectors. + void ReportStatisticsWithDuration( + MetricModelReporter* metric_reporter, bool success, + const uint64_t compute_start_ns, const uint64_t compute_input_duration_ns, + const uint64_t compute_infer_duration_ns, + const uint64_t compute_output_duration_ns); + + // Report the statistics to stats collectors associated with the request on + // response cache hits. + void ReportStatisticsCacheHit(MetricModelReporter* metric_reporter); + + // Report the statistics to stats collectors associated with the request on + // response cache misses and update request duration to include cache + // insertion time. + void ReportStatisticsCacheMiss(MetricModelReporter* metric_reporter); + + // Statistics for each request are aggregated into the corresponding + // model's statistics. Optionally this function may be used to + // add an additional aggregator where statistics are also aggregated. 
+  void SetSecondaryStatsAggregator(
+      InferenceStatsAggregator* secondary_stats_aggregator)
+  {
+    secondary_stats_aggregator_ = secondary_stats_aggregator;
+  }
+
+#endif  // TRITON_ENABLE_STATS
+
+ private:
+  DISALLOW_COPY_AND_ASSIGN(InferenceRequest);
+  friend std::ostream& operator<<(
+      std::ostream& out, const InferenceRequest& request);
+
+  Status Normalize();
+
+  // Whether anything in the request has potentially changed in a way
+  // that requires normalization when preparing the request for
+  // inference.
+  bool needs_normalization_;
+
+  // The model associated with this request. For most requests
+  // model_shared_ will be non-null and will act to keep the model
+  // alive as long as this request is live. In this case model_raw_
+  // will be the raw pointer from the shared pointer. For cases where
+  // the model itself created the request (like running requests for
+  // warmup), model_shared_ will be nullptr, but model_raw_ will
+  // still be defined. Thus model_raw_ is always defined and should
+  // always be used to access the model.
+  std::shared_ptr<Model> model_shared_;
+  Model* model_raw_;
+
+  // The model version as requested and, based on the version policy,
+  // the specific version that is actually used for inference.
+  int64_t requested_model_version_;
+  int64_t actual_model_version_;
+
+  std::string id_;
+
+  uint32_t flags_;
+  SequenceId correlation_id_;
+  uint32_t batch_size_;
+  uint32_t priority_;
+  uint64_t timeout_us_;
+  uint64_t cache_key_ = 0;
+  // Helper to determine if request was successfully hashed
+  // and cache_key_ field is valid
+  bool cache_key_is_set_ = false;
+
+  std::unordered_map<std::string, Input> original_inputs_;
+  std::unordered_map<std::string, std::shared_ptr<Input>> override_inputs_;
+  std::unordered_map inputs_;
+  std::set<std::string> original_requested_outputs_;
+  std::string raw_input_name_;
+  uint32_t raw_input_size_;
+
+  // requested_outputs_ is to be used post-normalization. It will be
+  // empty unless it differs from original_requested_outputs_, so it
+  // typically should be accessed through ImmutableRequestedOutputs.
+  std::set<std::string> requested_outputs_;
+
+  // The release function and user pointer for this request.
+  TRITONSERVER_InferenceRequestReleaseFn_t release_fn_;
+  void* release_userp_;
+
+  // Additional release callbacks invoked before 'release_fn_'.
+  std::vector<std::function<void()>> release_callbacks_;
+
+  // Delegator to be invoked on sending responses.
+  std::function<void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>
+      response_delegator_;
+
+  // The response factory associated with this request.
+  std::shared_ptr<InferenceResponseFactory> response_factory_;
+
+  // Request timestamps. Queue start is needed for schedulers even
+  // when statistics are not being collected.
+  uint64_t queue_start_ns_;
+
+  // Cache lookup start/end timestamps. Cache manages its own stats even
+  // when statistics are not being collected.
+  uint64_t cache_lookup_start_ns_;
+  uint64_t cache_lookup_end_ns_;
+
+  // Cache insertion start/end timestamps. Cache manages its own stats even
+  // when statistics are not being collected.
+  uint64_t cache_insertion_start_ns_;
+  uint64_t cache_insertion_end_ns_;
+
+  // Dedicated timestamp for batcher internal which can diverge from
+  // queue start timestamp to provide accurate queue time without affecting
+  // batcher functionalities.
+  uint64_t batcher_start_ns_;
+
+  // Whether the stats of the request should be collected.
+ bool collect_stats_; + +#ifdef TRITON_ENABLE_STATS + uint64_t request_start_ns_; + InferenceStatsAggregator* secondary_stats_aggregator_ = nullptr; +#endif // TRITON_ENABLE_STATS + +#ifdef TRITON_ENABLE_TRACING + // Inference trace associated with this request. + std::shared_ptr trace_; +#endif // TRITON_ENABLE_TRACING + + // Sequence I/O states used for implicit state. + std::shared_ptr sequence_states_; +}; + +std::ostream& operator<<(std::ostream& out, const InferenceRequest& request); +std::ostream& operator<<( + std::ostream& out, const InferenceRequest::Input& input); +std::ostream& operator<<( + std::ostream& out, const InferenceRequest::SequenceId& sequence_id); +bool operator==( + const InferenceRequest::SequenceId lhs, + const InferenceRequest::SequenceId rhs); +}} // namespace triton::core + +namespace std { +using namespace triton::core; +template <> +class hash { + public: + size_t operator()(const InferenceRequest::SequenceId& sequence_id) const + { + if (sequence_id.Type() == InferenceRequest::SequenceId::DataType::STRING) { + return std::hash{}(sequence_id.StringValue()); + } + return std::hash{}(sequence_id.UnsignedIntValue()); + } +}; +} // namespace std diff --git a/3rdparty/core-r22.12/src/infer_response.cc b/3rdparty/core-r22.12/src/infer_response.cc new file mode 100644 index 0000000000000000000000000000000000000000..2a8f2af2ec617149c31daa4566f632dbfb3b9ca7 --- /dev/null +++ b/3rdparty/core-r22.12/src/infer_response.cc @@ -0,0 +1,431 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
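
// Editorial note, not part of the Triton source: the std::hash specialization
// above (for InferenceRequest::SequenceId), together with the free operator==,
// lets a correlation id key standard unordered containers regardless of
// whether the client supplied a numeric or a string id. A minimal sketch,
// assuming only the declarations from infer_request.h above; the function name
// and the slot map are hypothetical and used purely for illustration.

#include <unordered_map>

static size_t
CorrelationIdSlot(
    const triton::core::InferenceRequest::SequenceId& correlation_id,
    std::unordered_map<triton::core::InferenceRequest::SequenceId, size_t>*
        slots)
{
  if (!correlation_id.InSequence()) {
    return 0;  // not part of a sequence; share a sentinel slot in this sketch
  }
  // Insert-if-absent: std::hash dispatches on DataType::UINT64 vs STRING.
  auto it = slots->emplace(correlation_id, slots->size() + 1);
  return it.first->second;
}
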
+ +#include "infer_response.h" + +#include "model.h" +#include "model_config_utils.h" +#include "server.h" +#include "triton/common/logging.h" + +namespace triton { namespace core { + +// +// InferenceResponseFactory +// +Status +InferenceResponseFactory::CreateResponse( + std::unique_ptr* response) const +{ + response->reset(new InferenceResponse( + model_, id_, allocator_, alloc_userp_, response_fn_, response_userp_, + response_delegator_)); +#ifdef TRITON_ENABLE_TRACING + (*response)->SetTrace(trace_); +#endif // TRITON_ENABLE_TRACING + return Status::Success; +} + +Status +InferenceResponseFactory::SendFlags(const uint32_t flags) const +{ + if (response_delegator_ != nullptr) { + std::unique_ptr response( + new InferenceResponse(response_fn_, response_userp_)); + response_delegator_(std::move(response), flags); + } else { + void* userp = response_userp_; + response_fn_(nullptr /* response */, flags, userp); + } + return Status::Success; +} + +// +// InferenceResponse +// +InferenceResponse::InferenceResponse( + const std::shared_ptr& model, const std::string& id, + const ResponseAllocator* allocator, void* alloc_userp, + TRITONSERVER_InferenceResponseCompleteFn_t response_fn, + void* response_userp, + const std::function< + void(std::unique_ptr&&, const uint32_t)>& delegator) + : model_(model), id_(id), allocator_(allocator), alloc_userp_(alloc_userp), + response_fn_(response_fn), response_userp_(response_userp), + response_delegator_(delegator), null_response_(false) +{ + // If the allocator has a start_fn then invoke it. + TRITONSERVER_ResponseAllocatorStartFn_t start_fn = allocator_->StartFn(); + if (start_fn != nullptr) { + LOG_TRITONSERVER_ERROR( + start_fn( + reinterpret_cast( + const_cast(allocator_)), + alloc_userp_), + "response allocation start failed"); + } +} + +InferenceResponse::InferenceResponse( + TRITONSERVER_InferenceResponseCompleteFn_t response_fn, + void* response_userp) + : response_fn_(response_fn), response_userp_(response_userp), + null_response_(true) +{ +} + +const std::string& +InferenceResponse::ModelName() const +{ + static const std::string unknown(""); + return (model_ == nullptr) ? unknown : model_->Name(); +} + +int64_t +InferenceResponse::ActualModelVersion() const +{ + return (model_ == nullptr) ? 
-1 : model_->Version(); +} + +Status +InferenceResponse::AddParameter(const char* name, const char* value) +{ + parameters_.emplace_back(name, value); + return Status::Success; +} + +Status +InferenceResponse::AddParameter(const char* name, const int64_t value) +{ + parameters_.emplace_back(name, value); + return Status::Success; +} + +Status +InferenceResponse::AddParameter(const char* name, const bool value) +{ + parameters_.emplace_back(name, value); + return Status::Success; +} + +Status +InferenceResponse::AddOutput( + const std::string& name, const inference::DataType datatype, + const std::vector& shape, InferenceResponse::Output** output) +{ + outputs_.emplace_back(name, datatype, shape, allocator_, alloc_userp_); + + LOG_VERBOSE(1) << "add response output: " << outputs_.back(); + + if (model_ != nullptr) { + const inference::ModelOutput* output_config; + RETURN_IF_ERROR(model_->GetOutput(name, &output_config)); + if (output_config->has_reshape()) { + const bool has_batch_dim = (model_->Config().max_batch_size() > 0); + outputs_.back().Reshape(has_batch_dim, output_config); + } + } + + if (output != nullptr) { + *output = std::addressof(outputs_.back()); + } + + return Status::Success; +} + +Status +InferenceResponse::AddOutput( + const std::string& name, const inference::DataType datatype, + std::vector&& shape, InferenceResponse::Output** output) +{ + outputs_.emplace_back( + name, datatype, std::move(shape), allocator_, alloc_userp_); + + LOG_VERBOSE(1) << "add response output: " << outputs_.back(); + + if (model_ != nullptr) { + const inference::ModelOutput* output_config; + RETURN_IF_ERROR(model_->GetOutput(name, &output_config)); + if (output_config->has_reshape()) { + const bool has_batch_dim = (model_->Config().max_batch_size() > 0); + outputs_.back().Reshape(has_batch_dim, output_config); + } + } + + if (output != nullptr) { + *output = std::addressof(outputs_.back()); + } + + return Status::Success; +} + +Status +InferenceResponse::ClassificationLabel( + const InferenceResponse::Output& output, const uint32_t class_index, + const char** label) const +{ + const auto& label_provider = model_->GetLabelProvider(); + const std::string& l = label_provider->GetLabel(output.Name(), class_index); + if (l.empty()) { + *label = nullptr; + } else { + *label = l.c_str(); + } + + return Status::Success; +} + +Status +InferenceResponse::Send( + std::unique_ptr&& response, const uint32_t flags) +{ +#ifdef TRITON_ENABLE_TRACING + response->TraceOutputTensors( + TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT, "InferenceResponse Send"); +#endif // TRITON_ENABLE_TRACING + + if (response->response_delegator_ != nullptr) { + auto ldelegator = std::move(response->response_delegator_); + ldelegator(std::move(response), flags); + return Status::Success; + } + void* userp = response->response_userp_; + if (response->null_response_) { + response->response_fn_(nullptr /* response */, flags, userp); + } else { + auto& response_fn = response->response_fn_; + response_fn( + reinterpret_cast(response.release()), + flags, userp); + } + return Status::Success; +} + +Status +InferenceResponse::SendWithStatus( + std::unique_ptr&& response, const uint32_t flags, + const Status& status) +{ + response->status_ = status; + return InferenceResponse::Send(std::move(response), flags); +} + +#ifdef TRITON_ENABLE_TRACING +Status +InferenceResponse::TraceOutputTensors( + TRITONSERVER_InferenceTraceActivity activity, const std::string& msg) +{ + const auto& outputs = this->Outputs(); + uint32_t output_count = 
outputs.size(); + + for (uint32_t idx = 0; idx < output_count; ++idx) { + const Output& output = outputs[idx]; + + // output data + const char* cname = output.Name().c_str(); + TRITONSERVER_DataType datatype = DataTypeToTriton(output.DType()); + const std::vector& oshape = output.Shape(); + const int64_t* shape = &oshape[0]; + uint64_t dim_count = oshape.size(); + const void* base; + size_t byte_size; + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + void* userp; + + Status status = output.DataBuffer( + &base, &byte_size, &memory_type, &memory_type_id, &userp); + if (!status.IsOk()) { + LOG_STATUS_ERROR( + status, + std::string(TRITONSERVER_InferenceTraceActivityString(activity)) + + ": " + msg + ": fail to get data buffer: " + status.Message()); + return status; + } + + INFER_TRACE_TENSOR_ACTIVITY( + this->trace_, activity, cname, datatype, base, byte_size, shape, + dim_count, memory_type, memory_type_id); + } + + return Status::Success; +} +#endif // TRITON_ENABLE_TRACING + +// +// InferenceResponse::Output +// +InferenceResponse::Output::~Output() +{ + Status status = ReleaseDataBuffer(); + if (!status.IsOk()) { + LOG_ERROR << "failed to release buffer for output '" << name_ + << "': " << status.AsString(); + } +} + +void +InferenceResponse::Output::Reshape( + const bool has_batch_dim, const inference::ModelOutput* output_config) +{ + std::deque variable_size_values; + + const int64_t batch_dim = + (has_batch_dim && (shape_.size() > 0)) ? shape_[0] : -1; + const size_t batch_dim_offset = (has_batch_dim) ? 1 : 0; + + const auto& from_shape = output_config->reshape().shape(); + const auto& to_shape = output_config->dims(); + for (int64_t idx = 0; idx < from_shape.size(); idx++) { + if (from_shape[idx] == -1) { + variable_size_values.push_back(shape_[idx + batch_dim_offset]); + } + } + + shape_.clear(); + if (batch_dim >= 0) { + shape_.push_back(batch_dim); + } + + for (const auto& dim : to_shape) { + if (dim == -1) { + shape_.push_back(variable_size_values.front()); + variable_size_values.pop_front(); + } else { + shape_.push_back(dim); + } + } +} + +Status +InferenceResponse::Output::DataBuffer( + const void** buffer, size_t* buffer_byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id, + void** userp) const +{ + *buffer = allocated_buffer_; + *buffer_byte_size = buffer_attributes_.ByteSize(); + *memory_type = buffer_attributes_.MemoryType(); + *memory_type_id = buffer_attributes_.MemoryTypeId(); + *userp = allocated_userp_; + return Status::Success; +} + +Status +InferenceResponse::Output::AllocateDataBuffer( + void** buffer, size_t buffer_byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) +{ + if (allocated_buffer_ != nullptr) { + return Status( + Status::Code::ALREADY_EXISTS, + "allocated buffer for output '" + name_ + "' already exists"); + } + + TRITONSERVER_MemoryType actual_memory_type = *memory_type; + int64_t actual_memory_type_id = *memory_type_id; + void* alloc_buffer_userp = nullptr; + + RETURN_IF_TRITONSERVER_ERROR(allocator_->AllocFn()( + reinterpret_cast( + const_cast(allocator_)), + name_.c_str(), buffer_byte_size, *memory_type, *memory_type_id, + alloc_userp_, buffer, &alloc_buffer_userp, &actual_memory_type, + &actual_memory_type_id)); + + // Only call the buffer attributes API if it is set. 
+ if (allocator_->BufferAttributesFn() != nullptr) { + RETURN_IF_TRITONSERVER_ERROR(allocator_->BufferAttributesFn()( + reinterpret_cast( + const_cast(allocator_)), + name_.c_str(), + reinterpret_cast(&buffer_attributes_), + alloc_userp_, alloc_buffer_userp)); + } + + allocated_buffer_ = *buffer; + buffer_attributes_.SetByteSize(buffer_byte_size); + buffer_attributes_.SetMemoryType(actual_memory_type); + buffer_attributes_.SetMemoryTypeId(actual_memory_type_id); + + allocated_userp_ = alloc_buffer_userp; + *memory_type = actual_memory_type; + *memory_type_id = actual_memory_type_id; + + return Status::Success; +} + +Status +InferenceResponse::Output::ReleaseDataBuffer() +{ + TRITONSERVER_Error* err = nullptr; + + if (allocated_buffer_ != nullptr) { + err = allocator_->ReleaseFn()( + reinterpret_cast( + const_cast(allocator_)), + allocated_buffer_, allocated_userp_, buffer_attributes_.ByteSize(), + buffer_attributes_.MemoryType(), buffer_attributes_.MemoryTypeId()); + } + + allocated_buffer_ = nullptr; + buffer_attributes_.SetByteSize(0); + buffer_attributes_.SetMemoryType(TRITONSERVER_MEMORY_CPU); + buffer_attributes_.SetMemoryTypeId(0); + allocated_userp_ = nullptr; + + RETURN_IF_TRITONSERVER_ERROR(err); + + return Status::Success; +} + +std::ostream& +operator<<(std::ostream& out, const InferenceResponse& response) +{ + out << "[0x" << std::addressof(response) << "] " + << "response id: " << response.Id() << ", model: " << response.ModelName() + << ", actual version: " << response.ActualModelVersion() << std::endl; + + out << "status:" << response.ResponseStatus().AsString() << std::endl; + + out << "outputs:" << std::endl; + for (const auto& output : response.Outputs()) { + out << "[0x" << std::addressof(output) << "] " << output << std::endl; + } + + return out; +} + +std::ostream& +operator<<(std::ostream& out, const InferenceResponse::Output& output) +{ + out << "output: " << output.Name() + << ", type: " << triton::common::DataTypeToProtocolString(output.DType()) + << ", shape: " << triton::common::DimsListToString(output.Shape()); + return out; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/infer_response.h b/3rdparty/core-r22.12/src/infer_response.h new file mode 100644 index 0000000000000000000000000000000000000000..783641558db643388bd06d5db665ceb4d4395980 --- /dev/null +++ b/3rdparty/core-r22.12/src/infer_response.h @@ -0,0 +1,351 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include "buffer_attributes.h" +#include "constants.h" +#include "infer_parameter.h" +#include "infer_trace.h" +#include "response_allocator.h" +#include "status.h" +#include "triton/common/model_config.h" +#include "tritonserver_apis.h" + +namespace triton { namespace core { + +class Model; +class InferenceResponse; +// +// An inference response factory. +// +class InferenceResponseFactory { + public: + InferenceResponseFactory() = default; + + InferenceResponseFactory( + const std::shared_ptr& model, const std::string& id, + const ResponseAllocator* allocator, void* alloc_userp, + TRITONSERVER_InferenceResponseCompleteFn_t response_fn, + void* response_userp, + const std::function&&, const uint32_t)>& delegator) + : model_(model), id_(id), allocator_(allocator), + alloc_userp_(alloc_userp), response_fn_(response_fn), + response_userp_(response_userp), response_delegator_(delegator) + { + } + + const ResponseAllocator* Allocator() { return allocator_; } + void* AllocatorUserp() { return alloc_userp_; } + + Status SetResponseDelegator( + const std::function&&, const uint32_t)>& delegator) + { + response_delegator_ = delegator; + return Status::Success; + } + + // Create a new response. + Status CreateResponse(std::unique_ptr* response) const; + + // Send a "null" response with 'flags'. + Status SendFlags(const uint32_t flags) const; + +#ifdef TRITON_ENABLE_TRACING + const std::shared_ptr& Trace() const { return trace_; } + void SetTrace(const std::shared_ptr& trace) + { + trace_ = trace; + } + void ReleaseTrace() { trace_ = nullptr; } +#endif // TRITON_ENABLE_TRACING + + private: + // The model associated with this factory. For normal + // requests/responses this will always be defined and acts to keep + // the model loaded as long as this factory is live. It may be + // nullptr for cases where the model itself created the request + // (like running requests for warmup) and so must protect any uses + // to handle the nullptr case. + std::shared_ptr model_; + + // The ID of the corresponding request that should be included in every + // response. This is a property that can be optionally provided by the user. + std::string id_; + + // The response allocator and user pointer. The 'allocator_' is a + // raw pointer because it is owned by the client, and the client is + // responsible for ensuring that the lifetime of the allocator + // extends longer that any request or response that depend on the + // allocator. + const ResponseAllocator* allocator_; + void* alloc_userp_; + + // The response callback function and user pointer. + TRITONSERVER_InferenceResponseCompleteFn_t response_fn_; + void* response_userp_; + + // Delegator to be invoked on sending responses. + std::function&&, const uint32_t)> + response_delegator_; + + +#ifdef TRITON_ENABLE_TRACING + // Inference trace associated with this response. + std::shared_ptr trace_; +#endif // TRITON_ENABLE_TRACING +}; + +// +// An inference response. 
+//
+class InferenceResponse {
+ public:
+  // Output tensor
+  class Output {
+   public:
+    Output(
+        const std::string& name, const inference::DataType datatype,
+        const std::vector<int64_t>& shape, const ResponseAllocator* allocator,
+        void* alloc_userp)
+        : name_(name), datatype_(datatype), shape_(shape),
+          allocator_(allocator), alloc_userp_(alloc_userp),
+          allocated_buffer_(nullptr)
+    {
+    }
+    Output(
+        const std::string& name, const inference::DataType datatype,
+        std::vector<int64_t>&& shape, const ResponseAllocator* allocator,
+        void* alloc_userp)
+        : name_(name), datatype_(datatype), shape_(std::move(shape)),
+          allocator_(allocator), alloc_userp_(alloc_userp),
+          allocated_buffer_(nullptr)
+    {
+    }
+
+    ~Output();
+
+    // The name of the output tensor.
+    const std::string& Name() const { return name_; }
+
+    // Data type of the output tensor.
+    inference::DataType DType() const { return datatype_; }
+
+    // The shape of the output tensor.
+    const std::vector<int64_t>& Shape() const { return shape_; }
+
+    BufferAttributes* GetBufferAttributes() { return &buffer_attributes_; }
+
+    // Reshape the output tensor. This function must only be called
+    // for outputs that have a reshape specified in the model
+    // configuration.
+    void Reshape(
+        const bool has_batch_dim, const inference::ModelOutput* output_config);
+
+    // Get information about the buffer allocated for this output
+    // tensor's data. If no buffer is allocated 'buffer' will return
+    // nullptr and the other returned values will be undefined.
+    Status DataBuffer(
+        const void** buffer, size_t* buffer_byte_size,
+        TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id,
+        void** userp) const;
+
+    // Allocate the buffer that should be used for this output
+    // tensor's data. 'buffer' must return a buffer of size
+    // 'buffer_byte_size'. 'memory_type' acts as both input and
+    // output. On input gives the buffer memory type preferred by the
+    // caller and on return holds the actual memory type of
+    // 'buffer'. 'memory_type_id' acts as both input and output. On
+    // input gives the buffer memory type id preferred by the caller
+    // and returns the actual memory type id of 'buffer'. Only a
+    // single buffer may be allocated for the output at any time, so
+    // multiple calls to AllocateDataBuffer without an intervening
+    // ReleaseDataBuffer call will result in an error.
+    Status AllocateDataBuffer(
+        void** buffer, const size_t buffer_byte_size,
+        TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id);
+
+    // Release the buffer that was previously allocated by
+    // AllocateDataBuffer(). Do nothing if AllocateDataBuffer() has
+    // not been called.
+    Status ReleaseDataBuffer();
+
+   private:
+    DISALLOW_COPY_AND_ASSIGN(Output);
+    friend std::ostream& operator<<(
+        std::ostream& out, const InferenceResponse::Output& output);
+
+    std::string name_;
+    inference::DataType datatype_;
+    std::vector<int64_t> shape_;
+
+    // The response allocator and user pointer.
+    const ResponseAllocator* allocator_;
+    void* alloc_userp_;
+
+    // Information about the buffer allocated by
+    // AllocateDataBuffer(). This information is needed by
+    // DataBuffer() and ReleaseDataBuffer().
+ void* allocated_buffer_; + BufferAttributes buffer_attributes_; + void* allocated_userp_; + }; + + // InferenceResponse + InferenceResponse( + const std::shared_ptr& model, const std::string& id, + const ResponseAllocator* allocator, void* alloc_userp, + TRITONSERVER_InferenceResponseCompleteFn_t response_fn, + void* response_userp, + const std::function&&, const uint32_t)>& delegator); + + // "null" InferenceResponse is a special instance of InferenceResponse which + // contains minimal information for calling InferenceResponse::Send, + // InferenceResponse::NullResponse. nullptr will be passed as response in + // 'response_fn'. + InferenceResponse( + TRITONSERVER_InferenceResponseCompleteFn_t response_fn, + void* response_userp); + + const std::string& Id() const { return id_; } + const std::string& ModelName() const; + int64_t ActualModelVersion() const; + const Status& ResponseStatus() const { return status_; } + + // The response parameters. + const std::deque& Parameters() const + { + return parameters_; + } + + // Add an parameter to the response. + Status AddParameter(const char* name, const char* value); + Status AddParameter(const char* name, const int64_t value); + Status AddParameter(const char* name, const bool value); + + // The response outputs. + const std::deque& Outputs() const { return outputs_; } + + // Add an output to the response. If 'output' is non-null + // return a pointer to the newly added output. + Status AddOutput( + const std::string& name, const inference::DataType datatype, + const std::vector& shape, Output** output = nullptr); + Status AddOutput( + const std::string& name, const inference::DataType datatype, + std::vector&& shape, Output** output = nullptr); + + // Get the classification label associated with an output. Return + // 'label' == nullptr if no label. + Status ClassificationLabel( + const Output& output, const uint32_t class_index, + const char** label) const; + + // Send the response with success status. Calling this function + // releases ownership of the response object and gives it to the + // callback function. + static Status Send( + std::unique_ptr&& response, const uint32_t flags); + + // Send the response with explicit status. Calling this function + // releases ownership of the response object and gives it to the + // callback function. + static Status SendWithStatus( + std::unique_ptr&& response, const uint32_t flags, + const Status& status); + +#ifdef TRITON_ENABLE_TRACING + const std::shared_ptr& Trace() const { return trace_; } + void SetTrace(const std::shared_ptr& trace) + { + trace_ = trace; + } + void ReleaseTrace() { trace_ = nullptr; } +#endif // TRITON_ENABLE_TRACING + + private: + DISALLOW_COPY_AND_ASSIGN(InferenceResponse); + friend std::ostream& operator<<( + std::ostream& out, const InferenceResponse& response); + +#ifdef TRITON_ENABLE_TRACING + Status TraceOutputTensors( + TRITONSERVER_InferenceTraceActivity activity, const std::string& msg); +#endif // TRITON_ENABLE_TRACING + + // The model associated with this factory. For normal + // requests/responses this will always be defined and acts to keep + // the model loaded as long as this factory is live. It may be + // nullptr for cases where the model itself created the request + // (like running requests for warmup) and so must protect any uses + // to handle the nullptr case. + std::shared_ptr model_; + + // The ID of the corresponding request that should be included in + // every response. + std::string id_; + + // Error status for the response. 
+ Status status_; + + // The parameters of the response. Use a deque so that there is no + // reallocation. + std::deque parameters_; + + // The result tensors. Use a deque so that there is no reallocation. + std::deque outputs_; + + // The response allocator and user pointer. + const ResponseAllocator* allocator_; + void* alloc_userp_; + + // The response callback function and user pointer. + TRITONSERVER_InferenceResponseCompleteFn_t response_fn_; + void* response_userp_; + + // Delegator to be invoked on sending responses. + std::function&&, const uint32_t)> + response_delegator_; + + bool null_response_; + +#ifdef TRITON_ENABLE_TRACING + // Inference trace associated with this response. + std::shared_ptr trace_; +#endif // TRITON_ENABLE_TRACING +}; + +std::ostream& operator<<(std::ostream& out, const InferenceResponse& response); +std::ostream& operator<<( + std::ostream& out, const InferenceResponse::Output& output); + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/infer_stats.cc b/3rdparty/core-r22.12/src/infer_stats.cc new file mode 100644 index 0000000000000000000000000000000000000000..1d33a1898c15896667e3e9603d1928397c32e166 --- /dev/null +++ b/3rdparty/core-r22.12/src/infer_stats.cc @@ -0,0 +1,241 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
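
// Editorial note, not part of the Triton source: a minimal sketch of the
// response lifecycle implied by infer_response.h above -- create the response
// through its factory, declare an output, allocate the output buffer through
// the client-provided allocator, then hand ownership to Send(). The function
// name, output name "OUTPUT0", the shape and the omitted Status checks are all
// illustrative assumptions, not code from this repository.

static void
SendSingleOutputSketch(const triton::core::InferenceResponseFactory& factory)
{
  using triton::core::InferenceResponse;

  // The factory carries the allocator, completion callback and optional
  // delegator configured on the originating request.
  std::unique_ptr<InferenceResponse> response;
  factory.CreateResponse(&response);  // error handling omitted in this sketch

  // Declare one output tensor.
  InferenceResponse::Output* output = nullptr;
  response->AddOutput(
      "OUTPUT0", inference::DataType::TYPE_FP32, std::vector<int64_t>{1, 4},
      &output);

  // Ask the response allocator for a buffer to hold the tensor data.
  void* buffer = nullptr;
  TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;
  int64_t memory_type_id = 0;
  output->AllocateDataBuffer(
      &buffer, 4 * sizeof(float), &memory_type, &memory_type_id);
  // ... write the computed values into 'buffer' ...

  // Send() takes ownership of the response; the FINAL flag marks the last
  // response for the request.
  InferenceResponse::Send(
      std::move(response), TRITONSERVER_RESPONSE_COMPLETE_FINAL);
}
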
+ +#include "infer_stats.h" + +#include +#include "metric_model_reporter.h" +#include "metrics.h" +#include "triton/common/logging.h" + +namespace triton { namespace core { + +#ifdef TRITON_ENABLE_STATS + +void +InferenceStatsAggregator::UpdateFailure( + MetricModelReporter* metric_reporter, const uint64_t request_start_ns, + const uint64_t request_end_ns) +{ + std::lock_guard lock(mu_); + + infer_stats_.failure_count_++; + infer_stats_.failure_duration_ns_ += (request_end_ns - request_start_ns); + +#ifdef TRITON_ENABLE_METRICS + if (metric_reporter != nullptr) { + metric_reporter->MetricInferenceFailure().Increment(1); + } +#endif // TRITON_ENABLE_METRICS +} + +void +InferenceStatsAggregator::UpdateSuccess( + MetricModelReporter* metric_reporter, const size_t batch_size, + const uint64_t request_start_ns, const uint64_t queue_start_ns, + const uint64_t compute_start_ns, const uint64_t compute_input_end_ns, + const uint64_t compute_output_start_ns, const uint64_t compute_end_ns, + const uint64_t request_end_ns) +{ + const uint64_t compute_input_duration_ns = + compute_input_end_ns - compute_start_ns; + const uint64_t compute_infer_duration_ns = + compute_output_start_ns - compute_input_end_ns; + const uint64_t compute_output_duration_ns = + compute_end_ns - compute_output_start_ns; + UpdateSuccessWithDuration( + metric_reporter, batch_size, request_start_ns, queue_start_ns, + compute_start_ns, request_end_ns, compute_input_duration_ns, + compute_infer_duration_ns, compute_output_duration_ns); +} + +void +InferenceStatsAggregator::UpdateSuccessWithDuration( + MetricModelReporter* metric_reporter, const size_t batch_size, + const uint64_t request_start_ns, const uint64_t queue_start_ns, + const uint64_t compute_start_ns, const uint64_t request_end_ns, + const uint64_t compute_input_duration_ns, + const uint64_t compute_infer_duration_ns, + const uint64_t compute_output_duration_ns) +{ + const uint64_t request_duration_ns = request_end_ns - request_start_ns; + const uint64_t queue_duration_ns = compute_start_ns - queue_start_ns; + + std::lock_guard lock(mu_); + + inference_count_ += batch_size; + + infer_stats_.success_count_++; + infer_stats_.request_duration_ns_ += request_duration_ns; + infer_stats_.queue_duration_ns_ += queue_duration_ns; + infer_stats_.compute_input_duration_ns_ += compute_input_duration_ns; + infer_stats_.compute_infer_duration_ns_ += compute_infer_duration_ns; + infer_stats_.compute_output_duration_ns_ += compute_output_duration_ns; + +#ifdef TRITON_ENABLE_METRICS + if (metric_reporter != nullptr) { + metric_reporter->MetricInferenceSuccess().Increment(1); + metric_reporter->MetricInferenceCount().Increment(batch_size); + metric_reporter->MetricInferenceRequestDuration().Increment( + request_duration_ns / 1000); + metric_reporter->MetricInferenceQueueDuration().Increment( + queue_duration_ns / 1000); + metric_reporter->MetricInferenceComputeInputDuration().Increment( + compute_input_duration_ns / 1000); + metric_reporter->MetricInferenceComputeInferDuration().Increment( + compute_infer_duration_ns / 1000); + metric_reporter->MetricInferenceComputeOutputDuration().Increment( + compute_output_duration_ns / 1000); + } +#endif // TRITON_ENABLE_METRICS +} + +// Currently cache hits will not go to the inference backend where metrics +// are typically updated, so this method allows us to update relevant metrics +// from a metric reporter rather than going through the backend. 
+void +InferenceStatsAggregator::UpdateSuccessCacheHit( + MetricModelReporter* metric_reporter, const size_t batch_size, + const uint64_t request_start_ns, const uint64_t queue_start_ns, + const uint64_t cache_lookup_start_ns, const uint64_t request_end_ns, + const uint64_t cache_hit_lookup_duration_ns) +{ + const uint64_t request_duration_ns = request_end_ns - request_start_ns; + const uint64_t queue_duration_ns = cache_lookup_start_ns - queue_start_ns; + + std::lock_guard lock(mu_); + + infer_stats_.success_count_++; + infer_stats_.request_duration_ns_ += request_duration_ns; + infer_stats_.queue_duration_ns_ += queue_duration_ns; + infer_stats_.cache_hit_count_++; + infer_stats_.cache_hit_lookup_duration_ns_ += cache_hit_lookup_duration_ns; + +#ifdef TRITON_ENABLE_METRICS + if (metric_reporter != nullptr) { + metric_reporter->MetricInferenceSuccess().Increment(1); + metric_reporter->MetricInferenceRequestDuration().Increment( + request_duration_ns / 1000); + metric_reporter->MetricInferenceQueueDuration().Increment( + queue_duration_ns / 1000); + metric_reporter->MetricCacheHitCount().Increment(1); + metric_reporter->MetricCacheHitLookupDuration().Increment( + cache_hit_lookup_duration_ns / 1000); + } +#endif // TRITON_ENABLE_METRICS +} + +// Cache misses will go to the inference backend where metrics are typically +// updated, but cache insertion happens after the inference backend finishes. +// So we use this method to update cache miss stats and adjust the request +// duration to include cache insertion time. +void +InferenceStatsAggregator::UpdateSuccessCacheMiss( + MetricModelReporter* metric_reporter, + const uint64_t cache_miss_lookup_duration_ns, + const uint64_t cache_miss_insertion_duration_ns) +{ + std::lock_guard lock(mu_); + + const uint64_t cache_miss_duration_ns = + cache_miss_lookup_duration_ns + cache_miss_insertion_duration_ns; + infer_stats_.request_duration_ns_ += cache_miss_duration_ns; + infer_stats_.cache_miss_count_++; + infer_stats_.cache_miss_lookup_duration_ns_ += cache_miss_lookup_duration_ns; + infer_stats_.cache_miss_insertion_duration_ns_ += + cache_miss_insertion_duration_ns; + +#ifdef TRITON_ENABLE_METRICS + if (metric_reporter != nullptr) { + // Add cache insertion time to request duration since insertion + // happens after inference backend sets the request duration, and + // cache lookup time was already included before the inference backend + // was called + metric_reporter->MetricInferenceRequestDuration().Increment( + cache_miss_duration_ns / 1000); + metric_reporter->MetricCacheMissCount().Increment(1); + metric_reporter->MetricCacheMissLookupDuration().Increment( + cache_miss_lookup_duration_ns / 1000); + metric_reporter->MetricCacheMissInsertionDuration().Increment( + cache_miss_insertion_duration_ns / 1000); + } +#endif // TRITON_ENABLE_METRICS +} + +void +InferenceStatsAggregator::UpdateInferBatchStats( + MetricModelReporter* metric_reporter, const size_t batch_size, + const uint64_t compute_start_ns, const uint64_t compute_input_end_ns, + const uint64_t compute_output_start_ns, const uint64_t compute_end_ns) +{ + auto compute_input_duration_ns = (compute_input_end_ns - compute_start_ns); + auto compute_infer_duration_ns = + (compute_output_start_ns - compute_input_end_ns); + auto compute_output_duration_ns = (compute_end_ns - compute_output_start_ns); + UpdateInferBatchStatsWithDuration( + metric_reporter, batch_size, compute_input_duration_ns, + compute_infer_duration_ns, compute_output_duration_ns); +} + +void 
+InferenceStatsAggregator::UpdateInferBatchStatsWithDuration( + MetricModelReporter* metric_reporter, size_t batch_size, + const uint64_t compute_input_duration_ns, + const uint64_t compute_infer_duration_ns, + const uint64_t compute_output_duration_ns) +{ + uint64_t inference_ms = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + + std::lock_guard lock(mu_); + + if (inference_ms > last_inference_ms_) { + last_inference_ms_ = inference_ms; + } + + execution_count_++; + + auto it = batch_stats_.find(batch_size); + if (it == batch_stats_.end()) { + it = batch_stats_.emplace(batch_size, InferBatchStats()).first; + } + it->second.count_++; + it->second.compute_input_duration_ns_ += compute_input_duration_ns; + it->second.compute_infer_duration_ns_ += compute_infer_duration_ns; + it->second.compute_output_duration_ns_ += compute_output_duration_ns; + +#ifdef TRITON_ENABLE_METRICS + if (metric_reporter != nullptr) { + metric_reporter->MetricInferenceExecutionCount().Increment(1); + } +#endif // TRITON_ENABLE_METRICS +} + +#endif // TRITON_ENABLE_STATS + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/infer_stats.h b/3rdparty/core-r22.12/src/infer_stats.h new file mode 100644 index 0000000000000000000000000000000000000000..b5e3be8429dc20b986db0a52f0760f15b2b96bfb --- /dev/null +++ b/3rdparty/core-r22.12/src/infer_stats.h @@ -0,0 +1,190 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include +#include "constants.h" +#include "infer_response.h" +#include "status.h" +#include "tritonserver_apis.h" + +namespace triton { namespace core { + +class MetricModelReporter; + + +// +// InferenceStatsAggregator +// +// A statistics aggregator. 
+// +class InferenceStatsAggregator { +#ifdef TRITON_ENABLE_STATS + public: + struct InferStats { + InferStats() + : failure_count_(0), failure_duration_ns_(0), success_count_(0), + request_duration_ns_(0), queue_duration_ns_(0), + compute_input_duration_ns_(0), compute_infer_duration_ns_(0), + compute_output_duration_ns_(0), cache_hit_count_(0), + cache_hit_lookup_duration_ns_(0), cache_miss_count_(0), + cache_miss_lookup_duration_ns_(0), + cache_miss_insertion_duration_ns_(0) + { + } + uint64_t failure_count_; + uint64_t failure_duration_ns_; + + uint64_t success_count_; + uint64_t request_duration_ns_; + uint64_t queue_duration_ns_; + uint64_t compute_input_duration_ns_; + uint64_t compute_infer_duration_ns_; + uint64_t compute_output_duration_ns_; + + // Cache hit stats + uint64_t cache_hit_count_; + uint64_t cache_hit_lookup_duration_ns_; + // Cache miss stats + uint64_t cache_miss_count_; + uint64_t cache_miss_lookup_duration_ns_; + uint64_t cache_miss_insertion_duration_ns_; + }; + + struct InferBatchStats { + InferBatchStats() + : count_(0), compute_input_duration_ns_(0), + compute_infer_duration_ns_(0), compute_output_duration_ns_(0) + { + } + uint64_t count_; + uint64_t compute_input_duration_ns_; + uint64_t compute_infer_duration_ns_; + uint64_t compute_output_duration_ns_; + }; + + // Create an aggregator for model statistics + InferenceStatsAggregator() + : last_inference_ms_(0), inference_count_(0), execution_count_(0) + { + } + + uint64_t LastInferenceMs() const { return last_inference_ms_; } + uint64_t InferenceCount() const { return inference_count_; } + uint64_t ExecutionCount() const { return execution_count_; } + const InferStats& ImmutableInferStats() const { return infer_stats_; } + const std::map& ImmutableInferBatchStats() const + { + return batch_stats_; + } + + // Add durations to Infer stats for a failed inference request. + void UpdateFailure( + MetricModelReporter* metric_reporter, const uint64_t request_start_ns, + const uint64_t request_end_ns); + + // Add durations to infer stats for a successful inference request. + void UpdateSuccess( + MetricModelReporter* metric_reporter, const size_t batch_size, + const uint64_t request_start_ns, const uint64_t queue_start_ns, + const uint64_t compute_start_ns, const uint64_t compute_input_end_ns, + const uint64_t compute_output_start_ns, const uint64_t compute_end_ns, + const uint64_t request_end_ns); + + // Add durations to infer stats for a successful inference request. + void UpdateSuccessWithDuration( + MetricModelReporter* metric_reporter, const size_t batch_size, + const uint64_t request_start_ns, const uint64_t queue_start_ns, + const uint64_t compute_start_ns, const uint64_t request_end_ns, + const uint64_t compute_input_duration_ns, + const uint64_t compute_infer_duration_ns, + const uint64_t compute_output_duration_ns); + + // Add durations to infer stats for a successful cached response. + void UpdateSuccessCacheHit( + MetricModelReporter* metric_reporter, const size_t batch_size, + const uint64_t request_start_ns, const uint64_t queue_start_ns, + const uint64_t cache_lookup_start_ns, const uint64_t request_end_ns, + const uint64_t cache_hit_lookup_duration_ns); + + // Add durations to infer stats for a cache miss and update request duration + // to account for cache insertion after backend computes the response. 
+  void UpdateSuccessCacheMiss(
+      MetricModelReporter* metric_reporter,
+      const uint64_t cache_miss_lookup_duration_ns,
+      const uint64_t cache_miss_insertion_duration_ns);
+
+  // Add durations to batch infer stats for a batch execution.
+  // 'success_request_count' is the number of success requests in the
+  // batch that have infer_stats attached.
+  void UpdateInferBatchStats(
+      MetricModelReporter* metric_reporter, const size_t batch_size,
+      const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
+      const uint64_t compute_output_start_ns, const uint64_t compute_end_ns);
+
+  // Add durations to batch infer stats for a batch execution.
+  // 'success_request_count' is the number of success requests in the
+  // batch that have infer_stats attached.
+  void UpdateInferBatchStatsWithDuration(
+      MetricModelReporter* metric_reporter, size_t batch_size,
+      const uint64_t compute_input_duration_ns,
+      const uint64_t compute_infer_duration_ns,
+      const uint64_t compute_output_duration_ns);
+
+ private:
+  std::mutex mu_;
+  uint64_t last_inference_ms_;
+  uint64_t inference_count_;
+  uint64_t execution_count_;
+  InferStats infer_stats_;
+  std::map<size_t, InferBatchStats> batch_stats_;
+#endif  // TRITON_ENABLE_STATS
+};
+
+
+//
+// Macros to set infer stats.
+//
+#ifdef TRITON_ENABLE_STATS
+#define INFER_STATS_SET_TIMESTAMP(TS_NS)                              \
+  {                                                                   \
+    TS_NS = std::chrono::duration_cast<std::chrono::nanoseconds>(     \
+                std::chrono::steady_clock::now().time_since_epoch())  \
+                .count();                                             \
+  }
+#define INFER_STATS_DECL_TIMESTAMP(TS_NS) \
+  uint64_t TS_NS;                         \
+  INFER_STATS_SET_TIMESTAMP(TS_NS);
+#else
+#define INFER_STATS_DECL_TIMESTAMP(TS_NS)
+#define INFER_STATS_SET_TIMESTAMP(TS_NS)
+#endif  // TRITON_ENABLE_STATS
+
+}}  // namespace triton::core
diff --git a/3rdparty/core-r22.12/src/infer_trace.cc b/3rdparty/core-r22.12/src/infer_trace.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cce46e26283188b7afef2ca2296d206cbf2791a5
--- /dev/null
+++ b/3rdparty/core-r22.12/src/infer_trace.cc
@@ -0,0 +1,61 @@
+// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//  * Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//  * Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
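
// Editorial note, not part of the Triton source: a minimal sketch of how the
// timestamp macros and InferenceStatsAggregator declared in infer_stats.h
// above fit together around one (hypothetical) successful execution. In the
// real server each timestamp is captured at the matching point in the request
// lifecycle rather than back to back as here; the function name and parameters
// are illustrative assumptions.

#ifdef TRITON_ENABLE_STATS
static void
RecordSuccessSketch(
    triton::core::InferenceStatsAggregator* aggregator,
    triton::core::MetricModelReporter* reporter, const size_t batch_size)
{
  INFER_STATS_DECL_TIMESTAMP(request_start_ns);
  INFER_STATS_DECL_TIMESTAMP(queue_start_ns);
  INFER_STATS_DECL_TIMESTAMP(compute_start_ns);
  // ... gather inputs, run the model, copy outputs ...
  INFER_STATS_DECL_TIMESTAMP(compute_input_end_ns);
  INFER_STATS_DECL_TIMESTAMP(compute_output_start_ns);
  INFER_STATS_DECL_TIMESTAMP(compute_end_ns);
  INFER_STATS_DECL_TIMESTAMP(request_end_ns);

  // Queue and compute durations are derived from the timestamp pairs inside
  // UpdateSuccess; 'reporter' may be nullptr when per-model metrics are
  // disabled.
  aggregator->UpdateSuccess(
      reporter, batch_size, request_start_ns, queue_start_ns, compute_start_ns,
      compute_input_end_ns, compute_output_start_ns, compute_end_ns,
      request_end_ns);
}
#endif  // TRITON_ENABLE_STATS
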
+
+#include "infer_trace.h"
+
+namespace triton { namespace core {
+
+#ifdef TRITON_ENABLE_TRACING
+
+// Start the trace id at 1, because id 0 is reserved to indicate no
+// parent.
+std::atomic<uint64_t> InferenceTrace::next_id_(1);
+
+InferenceTrace*
+InferenceTrace::SpawnChildTrace()
+{
+  InferenceTrace* trace = new InferenceTrace(
+      level_, id_, activity_fn_, tensor_activity_fn_, release_fn_, userp_);
+  return trace;
+}
+
+void
+InferenceTrace::Release()
+{
+  release_fn_(reinterpret_cast<TRITONSERVER_InferenceTrace*>(this), userp_);
+}
+
+std::shared_ptr<InferenceTraceProxy>
+InferenceTraceProxy::SpawnChildTrace()
+{
+  std::shared_ptr<InferenceTraceProxy> strace_proxy =
+      std::make_shared<InferenceTraceProxy>(trace_->SpawnChildTrace());
+  return strace_proxy;
+}
+
+#endif  // TRITON_ENABLE_TRACING
+
+}}  // namespace triton::core
diff --git a/3rdparty/core-r22.12/src/infer_trace.h b/3rdparty/core-r22.12/src/infer_trace.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2696c24a2f09ea9cad268981a47b5cddda71f61
--- /dev/null
+++ b/3rdparty/core-r22.12/src/infer_trace.h
@@ -0,0 +1,205 @@
+// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//  * Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//  * Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#pragma once
+
+#include
+#include
+#include
+#include "constants.h"
+#include "status.h"
+#include "tritonserver_apis.h"
+
+namespace triton { namespace core {
+
+#ifdef TRITON_ENABLE_TRACING
+
+//
+// InferenceTrace
+//
+// Interface to TRITONSERVER_InferenceTrace to report trace events.
+// +class InferenceTrace { + public: + InferenceTrace( + const TRITONSERVER_InferenceTraceLevel level, const uint64_t parent_id, + TRITONSERVER_InferenceTraceActivityFn_t activity_fn, + TRITONSERVER_InferenceTraceTensorActivityFn_t tensor_activity_fn, + TRITONSERVER_InferenceTraceReleaseFn_t release_fn, void* userp) + : level_(level), id_(next_id_++), parent_id_(parent_id), + activity_fn_(activity_fn), tensor_activity_fn_(tensor_activity_fn), + release_fn_(release_fn), userp_(userp) + { + } + + InferenceTrace* SpawnChildTrace(); + + int64_t Id() const { return id_; } + int64_t ParentId() const { return parent_id_; } + + const std::string& ModelName() const { return model_name_; } + int64_t ModelVersion() const { return model_version_; } + + void SetModelName(const std::string& n) { model_name_ = n; } + void SetModelVersion(int64_t v) { model_version_ = v; } + + // Report trace activity. + void Report( + const TRITONSERVER_InferenceTraceActivity activity, uint64_t timestamp_ns) + { + if ((level_ & TRITONSERVER_TRACE_LEVEL_TIMESTAMPS) > 0) { + activity_fn_( + reinterpret_cast(this), activity, + timestamp_ns, userp_); + } + } + + // Report trace activity at the current time. + void ReportNow(const TRITONSERVER_InferenceTraceActivity activity) + { + if ((level_ & TRITONSERVER_TRACE_LEVEL_TIMESTAMPS) > 0) { + Report( + activity, std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count()); + } + } + + // Report tensor trace activity. + void ReportTensor( + const TRITONSERVER_InferenceTraceActivity activity, const char* name, + TRITONSERVER_DataType datatype, const void* base, size_t byte_size, + const int64_t* shape, uint64_t dim_count, + TRITONSERVER_MemoryType memory_type, int64_t memory_type_id) + { + if ((level_ & TRITONSERVER_TRACE_LEVEL_TENSORS) > 0) { + tensor_activity_fn_( + reinterpret_cast(this), activity, name, + datatype, base, byte_size, shape, dim_count, memory_type, + memory_type_id, userp_); + } + } + + // Release the trace. Call the trace release callback. + void Release(); + + private: + const TRITONSERVER_InferenceTraceLevel level_; + const uint64_t id_; + const uint64_t parent_id_; + + TRITONSERVER_InferenceTraceActivityFn_t activity_fn_; + TRITONSERVER_InferenceTraceTensorActivityFn_t tensor_activity_fn_; + TRITONSERVER_InferenceTraceReleaseFn_t release_fn_; + void* userp_; + + std::string model_name_; + int64_t model_version_; + + // Maintain next id statically so that trace id is unique even + // across traces + static std::atomic next_id_; +}; + +// +// InferenceTraceProxy +// +// Object attached as shared_ptr to InferenceRequest and +// InferenceResponse(s) being traced as part of a single inference +// request. 
+// +class InferenceTraceProxy { + public: + InferenceTraceProxy(InferenceTrace* trace) : trace_(trace) {} + ~InferenceTraceProxy() { trace_->Release(); } + int64_t Id() const { return trace_->Id(); } + int64_t ParentId() const { return trace_->ParentId(); } + const std::string& ModelName() const { return trace_->ModelName(); } + int64_t ModelVersion() const { return trace_->ModelVersion(); } + void SetModelName(const std::string& n) { trace_->SetModelName(n); } + void SetModelVersion(int64_t v) { trace_->SetModelVersion(v); } + + void Report( + const TRITONSERVER_InferenceTraceActivity activity, uint64_t timestamp_ns) + { + trace_->Report(activity, timestamp_ns); + } + + void ReportNow(const TRITONSERVER_InferenceTraceActivity activity) + { + trace_->ReportNow(activity); + } + + void ReportTensor( + const TRITONSERVER_InferenceTraceActivity activity, const char* name, + TRITONSERVER_DataType datatype, const void* base, size_t byte_size, + const int64_t* shape, uint64_t dim_count, + TRITONSERVER_MemoryType memory_type, int64_t memory_type_id) + { + trace_->ReportTensor( + activity, name, datatype, base, byte_size, shape, dim_count, + memory_type, memory_type_id); + } + + std::shared_ptr SpawnChildTrace(); + + private: + InferenceTrace* trace_; +}; + +#endif // TRITON_ENABLE_TRACING + +// +// Macros to generate trace activity +// +#ifdef TRITON_ENABLE_TRACING +#define INFER_TRACE_ACTIVITY(T, A, TS_NS) \ + { \ + const auto& trace = (T); \ + const auto ts_ns = (TS_NS); \ + if (trace != nullptr) { \ + trace->Report(A, ts_ns); \ + } \ + } +#define INFER_TRACE_ACTIVITY_NOW(T, A) \ + { \ + const auto& trace = (T); \ + if (trace != nullptr) { \ + trace->ReportNow(A); \ + } \ + } +#define INFER_TRACE_TENSOR_ACTIVITY(T, A, N, D, BA, BY, S, DI, MT, MTI) \ + { \ + const auto& trace = (T); \ + if (trace != nullptr) { \ + trace->ReportTensor(A, N, D, BA, BY, S, DI, MT, MTI); \ + } \ + } +#else +#define INFER_TRACE_ACTIVITY(T, A, TS_NS) +#define INFER_TRACE_ACTIVITY_NOW(T, A) +#define INFER_TRACE_TENSOR_ACTIVITY(T, A, N, D, BA, BY, S, DI, MT, MTI) +#endif // TRITON_ENABLE_TRACING +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/instance_queue.cc b/3rdparty/core-r22.12/src/instance_queue.cc new file mode 100644 index 0000000000000000000000000000000000000000..0aaccb7183ff7e676e8d7b99ca72a5973c380ee0 --- /dev/null +++ b/3rdparty/core-r22.12/src/instance_queue.cc @@ -0,0 +1,99 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "instance_queue.h"
+
+#include "triton/common/logging.h"
+
+namespace triton { namespace core {
+
+InstanceQueue::InstanceQueue(size_t max_batch_size, uint64_t max_queue_delay_ns)
+    : max_batch_size_(max_batch_size), max_queue_delay_ns_(max_queue_delay_ns)
+{
+}
+
+size_t
+InstanceQueue::Size()
+{
+  return payload_queue_.size();
+}
+
+bool
+InstanceQueue::Empty()
+{
+  return payload_queue_.empty();
+}
+
+void
+InstanceQueue::Enqueue(const std::shared_ptr<Payload>& payload)
+{
+  payload_queue_.push_back(payload);
+}
+
+void
+InstanceQueue::Dequeue(
+    std::shared_ptr<Payload>* payload,
+    std::vector<std::shared_ptr<Payload>>* merged_payloads)
+{
+  *payload = payload_queue_.front();
+  payload_queue_.pop_front();
+  {
+    std::lock_guard<std::mutex> exec_lock(*((*payload)->GetExecMutex()));
+    (*payload)->SetState(Payload::State::EXECUTING);
+    if ((!payload_queue_.empty()) && (max_queue_delay_ns_ > 0) &&
+        (max_batch_size_ > 1) && (!(*payload)->IsSaturated())) {
+      bool continue_merge;
+      do {
+        continue_merge = false;
+        uint64_t now_ns =
+            std::chrono::duration_cast<std::chrono::nanoseconds>(
+                std::chrono::steady_clock::now().time_since_epoch())
+                .count();
+        size_t batch_size = (*payload)->BatchSize();
+        if ((!payload_queue_.empty()) &&
+            (!payload_queue_.front()->IsSaturated()) &&
+            (now_ns - payload_queue_.front()->BatcherStartNs()) >
+                max_queue_delay_ns_) {
+          std::lock_guard<std::mutex> exec_lock(
+              *(payload_queue_.front()->GetExecMutex()));
+          payload_queue_.front()->SetState(Payload::State::EXECUTING);
+          size_t front_batch_size = payload_queue_.front()->BatchSize();
+          if ((batch_size + front_batch_size) <= max_batch_size_) {
+            const auto& status =
+                (*payload)->MergePayload(payload_queue_.front());
+            if (status.IsOk()) {
+              merged_payloads->push_back(payload_queue_.front());
+              payload_queue_.pop_front();
+              continue_merge = true;
+            }
+          }
+        }
+      } while (continue_merge);
+    }
+  }
+}
+
+}}  // namespace triton::core
diff --git a/3rdparty/core-r22.12/src/instance_queue.h b/3rdparty/core-r22.12/src/instance_queue.h
new file mode 100644
index 0000000000000000000000000000000000000000..da25a460c0aa1c963c5c8141923e4a9a2628329a
--- /dev/null
+++ b/3rdparty/core-r22.12/src/instance_queue.h
@@ -0,0 +1,57 @@
+// Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//  * Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//  * Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include "payload.h" + +namespace triton { namespace core { + +// +// InstanceQueue +// +// A queue implementation holding Payloads ready to be scheduled on +// model instance. +class InstanceQueue { + public: + explicit InstanceQueue(size_t max_batch_size, uint64_t max_queue_delay_ns); + + size_t Size(); + bool Empty(); + void Enqueue(const std::shared_ptr& payload); + void Dequeue( + std::shared_ptr* payload, + std::vector>* merged_payloads); + + private: + size_t max_batch_size_; + uint64_t max_queue_delay_ns_; + + std::deque> payload_queue_; + std::shared_ptr staged_payload_; + std::mutex mu_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/label_provider.cc b/3rdparty/core-r22.12/src/label_provider.cc new file mode 100644 index 0000000000000000000000000000000000000000..cff489453f3f69434b2038e346f15290eb0cc19e --- /dev/null +++ b/3rdparty/core-r22.12/src/label_provider.cc @@ -0,0 +1,95 @@ +// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
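For orientation on the InstanceQueue just added: Dequeue() pops the front payload, marks it EXECUTING, and then keeps merging queued payloads that are not saturated and have waited longer than max_queue_delay_ns, provided the combined batch still fits within max_batch_size. A rough usage sketch follows; the free function below, and the assumption that the caller provides whatever synchronization the surrounding scheduler requires, are illustrative and not taken from this diff.

#include <memory>
#include <vector>

#include "instance_queue.h"

// Sketch only: enqueue one payload and immediately dequeue work for a model
// instance. Real schedulers enqueue and dequeue from different threads.
void
ScheduleOne(
    triton::core::InstanceQueue& queue,
    const std::shared_ptr<triton::core::Payload>& incoming)
{
  queue.Enqueue(incoming);

  std::shared_ptr<triton::core::Payload> payload;
  std::vector<std::shared_ptr<triton::core::Payload>> merged_payloads;
  queue.Dequeue(&payload, &merged_payloads);

  // 'payload' is now marked EXECUTING and may have absorbed the requests of
  // every entry in 'merged_payloads' that exceeded max_queue_delay_ns; the
  // caller executes 'payload' and releases the merged payloads afterwards.
}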
+ +#include "label_provider.h" + +#include +#include +#include +#include "filesystem.h" + +namespace triton { namespace core { + +const std::string& +LabelProvider::GetLabel(const std::string& name, size_t index) const +{ + static const std::string not_found; + + auto itr = label_map_.find(name); + if (itr == label_map_.end()) { + return not_found; + } + + if (itr->second.size() <= index) { + return not_found; + } + + return itr->second[index]; +} + +Status +LabelProvider::AddLabels(const std::string& name, const std::string& filepath) +{ + std::string label_file_contents; + RETURN_IF_ERROR(ReadTextFile(filepath, &label_file_contents)); + + auto p = label_map_.insert(std::make_pair(name, std::vector())); + if (!p.second) { + return Status( + Status::Code::INTERNAL, "multiple label files for '" + name + "'"); + } + + auto itr = p.first; + + std::istringstream label_file_stream(label_file_contents); + std::string line; + while (std::getline(label_file_stream, line)) { + itr->second.push_back(line); + } + + return Status::Success; +} + +const std::vector& +LabelProvider::GetLabels(const std::string& name) +{ + static const std::vector not_found; + auto itr = label_map_.find(name); + if (itr == label_map_.end()) { + return not_found; + } + return itr->second; +} + +Status +LabelProvider::AddLabels( + const std::string& name, const std::vector& labels) +{ + label_map_.emplace(name, labels); + return Status::Success; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/label_provider.h b/3rdparty/core-r22.12/src/label_provider.h new file mode 100644 index 0000000000000000000000000000000000000000..ebbd1894772ab6e4169d6b1f6717c72239d9f4e8 --- /dev/null +++ b/3rdparty/core-r22.12/src/label_provider.h @@ -0,0 +1,65 @@ +// Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include "constants.h" +#include "status.h" + +namespace triton { namespace core { + +// Provides classification labels. 
+class LabelProvider { + public: + LabelProvider() = default; + + // Return the label associated with 'name' for a given + // 'index'. Return empty string if no label is available. + const std::string& GetLabel(const std::string& name, size_t index) const; + + // Associate with 'name' a set of labels initialized from a given + // 'filepath'. Within the file each label is specified on its own + // line. The first label (line 0) is the index-0 label, the second + // label (line 1) is the index-1 label, etc. + Status AddLabels(const std::string& name, const std::string& filepath); + + // Return the labels associated with 'name'. Return empty vector if no labels + // are available. + const std::vector& GetLabels(const std::string& name); + + // Associate with 'name' a set of 'labels' + Status AddLabels( + const std::string& name, const std::vector& labels); + + private: + DISALLOW_COPY_AND_ASSIGN(LabelProvider); + + std::unordered_map> label_map_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/libtritonserver.ldscript b/3rdparty/core-r22.12/src/libtritonserver.ldscript new file mode 100644 index 0000000000000000000000000000000000000000..055d5df1980898e0a7870d39c494dbcc43e878f1 --- /dev/null +++ b/3rdparty/core-r22.12/src/libtritonserver.ldscript @@ -0,0 +1,32 @@ +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +{ + global: + TRITONSERVER_*; + TRITONBACKEND_*; + TRITONREPOAGENT_*; + local: *; +}; diff --git a/3rdparty/core-r22.12/src/memory.cc b/3rdparty/core-r22.12/src/memory.cc new file mode 100644 index 0000000000000000000000000000000000000000..7d44f4b7114d25fe3cef7d22e12e9baf076c862f --- /dev/null +++ b/3rdparty/core-r22.12/src/memory.cc @@ -0,0 +1,238 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "memory.h" + +#include "pinned_memory_manager.h" +#include "triton/common/logging.h" + +#ifdef TRITON_ENABLE_GPU +#include +#include "cuda_memory_manager.h" +#endif // TRITON_ENABLE_GPU + +namespace triton { namespace core { + +// +// MemoryReference +// +MemoryReference::MemoryReference() : Memory() {} + +const char* +MemoryReference::BufferAt( + size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type, + int64_t* memory_type_id) const +{ + if (idx >= buffer_.size()) { + *byte_size = 0; + *memory_type = TRITONSERVER_MEMORY_CPU; + *memory_type_id = 0; + return nullptr; + } + *memory_type = buffer_[idx].buffer_attributes_.MemoryType(); + *memory_type_id = buffer_[idx].buffer_attributes_.MemoryTypeId(); + *byte_size = buffer_[idx].buffer_attributes_.ByteSize(); + return buffer_[idx].buffer_; +} + +const char* +MemoryReference::BufferAt(size_t idx, BufferAttributes** buffer_attributes) +{ + if (idx >= buffer_.size()) { + *buffer_attributes = nullptr; + return nullptr; + } + + *buffer_attributes = &(buffer_[idx].buffer_attributes_); + return buffer_[idx].buffer_; +} + +size_t +MemoryReference::AddBuffer( + const char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) +{ + total_byte_size_ += byte_size; + buffer_count_++; + buffer_.emplace_back(buffer, byte_size, memory_type, memory_type_id); + return buffer_.size() - 1; +} + +size_t +MemoryReference::AddBuffer( + const char* buffer, BufferAttributes* buffer_attributes) +{ + total_byte_size_ += buffer_attributes->ByteSize(); + buffer_count_++; + buffer_.emplace_back(buffer, buffer_attributes); + return buffer_.size() - 1; +} + +size_t +MemoryReference::AddBufferFront( + const char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) +{ + total_byte_size_ += byte_size; + buffer_count_++; + buffer_.emplace( + buffer_.begin(), buffer, byte_size, memory_type, memory_type_id); + return buffer_.size() - 1; +} + +// +// MutableMemory +// +MutableMemory::MutableMemory( + char* buffer, 
size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) + : Memory(), buffer_(buffer), + buffer_attributes_( + BufferAttributes(byte_size, memory_type, memory_type_id, nullptr)) +{ + total_byte_size_ = byte_size; + buffer_count_ = (byte_size == 0) ? 0 : 1; +} + +const char* +MutableMemory::BufferAt( + size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type, + int64_t* memory_type_id) const +{ + if (idx != 0) { + *byte_size = 0; + *memory_type = TRITONSERVER_MEMORY_CPU; + *memory_type_id = 0; + return nullptr; + } + *byte_size = total_byte_size_; + *memory_type = buffer_attributes_.MemoryType(); + *memory_type_id = buffer_attributes_.MemoryTypeId(); + return buffer_; +} + +const char* +MutableMemory::BufferAt(size_t idx, BufferAttributes** buffer_attributes) +{ + if (idx != 0) { + *buffer_attributes = nullptr; + return nullptr; + } + + *buffer_attributes = &buffer_attributes_; + return buffer_; +} + +char* +MutableMemory::MutableBuffer( + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) +{ + if (memory_type != nullptr) { + *memory_type = buffer_attributes_.MemoryType(); + } + if (memory_type_id != nullptr) { + *memory_type_id = buffer_attributes_.MemoryTypeId(); + } + + return buffer_; +} + +// +// AllocatedMemory +// +AllocatedMemory::AllocatedMemory( + size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) + : MutableMemory(nullptr, byte_size, memory_type, memory_type_id) +{ + if (total_byte_size_ != 0) { + // Allocate memory with the following fallback policy: + // CUDA memory -> pinned system memory -> non-pinned system memory + switch (buffer_attributes_.MemoryType()) { +#ifdef TRITON_ENABLE_GPU + case TRITONSERVER_MEMORY_GPU: { + auto status = CudaMemoryManager::Alloc( + (void**)&buffer_, total_byte_size_, + buffer_attributes_.MemoryTypeId()); + if (!status.IsOk()) { + static bool warning_logged = false; + if (!warning_logged) { + LOG_WARNING << status.Message() + << ", falling back to pinned system memory"; + warning_logged = true; + } + + goto pinned_memory_allocation; + } + break; + } + pinned_memory_allocation: +#endif // TRITON_ENABLE_GPU + default: { + TRITONSERVER_MemoryType memory_type = buffer_attributes_.MemoryType(); + auto status = PinnedMemoryManager::Alloc( + (void**)&buffer_, total_byte_size_, &memory_type, true); + buffer_attributes_.SetMemoryType(memory_type); + if (!status.IsOk()) { + LOG_ERROR << status.Message(); + buffer_ = nullptr; + } + break; + } + } + } + total_byte_size_ = (buffer_ == nullptr) ? 0 : total_byte_size_; +} + +AllocatedMemory::~AllocatedMemory() +{ + if (buffer_ != nullptr) { + switch (buffer_attributes_.MemoryType()) { + case TRITONSERVER_MEMORY_GPU: { +#ifdef TRITON_ENABLE_GPU + auto status = + CudaMemoryManager::Free(buffer_, buffer_attributes_.MemoryTypeId()); + if (!status.IsOk()) { + LOG_ERROR << status.Message(); + } +#endif // TRITON_ENABLE_GPU + break; + } + + default: { + auto status = PinnedMemoryManager::Free(buffer_); + if (!status.IsOk()) { + LOG_ERROR << status.Message(); + buffer_ = nullptr; + } + break; + } + } + buffer_ = nullptr; + } +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/memory.h b/3rdparty/core-r22.12/src/memory.h new file mode 100644 index 0000000000000000000000000000000000000000..fad58db09e3fbe7c33dac07034dd8c5689c280ce --- /dev/null +++ b/3rdparty/core-r22.12/src/memory.h @@ -0,0 +1,174 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include "buffer_attributes.h" +#include "constants.h" +#include "status.h" + +namespace triton { namespace core { + +// +// Memory used to access data in inference requests +// +class Memory { + public: + // Get the 'idx'-th data block in the buffer. Using index to avoid + // maintaining internal state such that one buffer can be shared + // across multiple providers. + // 'idx' zero base index. Valid indices are continuous. + // 'byte_size' returns the byte size of the chunk of bytes. + // 'memory_type' returns the memory type of the chunk of bytes. + // 'memory_type_id' returns the memory type id of the chunk of bytes. + // Return the pointer to the data block. Returns nullptr if 'idx' is + // out of range + virtual const char* BufferAt( + size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type, + int64_t* memory_type_id) const = 0; + + // Similar to the above BufferAt but with BufferAttributes. + virtual const char* BufferAt( + size_t idx, BufferAttributes** buffer_attributes) = 0; + + // Get the number of contiguous buffers composing the memory. 
+ size_t BufferCount() const { return buffer_count_; } + + // Return the total byte size of the data buffer + size_t TotalByteSize() const { return total_byte_size_; } + + protected: + Memory() : total_byte_size_(0), buffer_count_(0) {} + size_t total_byte_size_; + size_t buffer_count_; +}; + +// +// MemoryReference +// +class MemoryReference : public Memory { + public: + // Create a read-only data buffer as a reference to other data buffer + MemoryReference(); + + //\see Memory::BufferAt() + const char* BufferAt( + size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type, + int64_t* memory_type_id) const override; + + const char* BufferAt( + size_t idx, BufferAttributes** buffer_attributes) override; + + // Add a 'buffer' with 'byte_size' as part of this data buffer + // Return the index of the buffer + size_t AddBuffer( + const char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id); + + size_t AddBuffer(const char* buffer, BufferAttributes* buffer_attributes); + + // Add a 'buffer' with 'byte_size' as part of this data buffer in the front + // Return the index of the buffer + size_t AddBufferFront( + const char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id); + + private: + struct Block { + Block( + const char* buffer, size_t byte_size, + TRITONSERVER_MemoryType memory_type, int64_t memory_type_id) + : buffer_(buffer), buffer_attributes_(BufferAttributes( + byte_size, memory_type, memory_type_id, nullptr)) + { + } + + Block(const char* buffer, BufferAttributes* buffer_attributes) + : buffer_(buffer), buffer_attributes_(*buffer_attributes) + { + } + const char* buffer_; + BufferAttributes buffer_attributes_; + }; + std::vector buffer_; +}; + +// +// MutableMemory +// +class MutableMemory : public Memory { + public: + // Create a mutable data buffer referencing to other data buffer. + MutableMemory( + char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id); + + virtual ~MutableMemory() {} + + //\see Memory::BufferAt() + const char* BufferAt( + size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type, + int64_t* memory_type_id) const override; + + //\see Memory::BufferAt() + const char* BufferAt( + size_t idx, BufferAttributes** buffer_attributes) override; + + // Return a pointer to the base address of the mutable buffer. If + // non-null 'memory_type' returns the memory type of the chunk of + // bytes. If non-null 'memory_type_id' returns the memory type id of + // the chunk of bytes. + char* MutableBuffer( + TRITONSERVER_MemoryType* memory_type = nullptr, + int64_t* memory_type_id = nullptr); + + DISALLOW_COPY_AND_ASSIGN(MutableMemory); + + protected: + MutableMemory() : Memory() {} + + char* buffer_; + BufferAttributes buffer_attributes_; +}; + +// +// AllocatedMemory +// +class AllocatedMemory : public MutableMemory { + public: + // Create a continuous data buffer with 'byte_size', 'memory_type' and + // 'memory_type_id'. Note that the buffer may be created on different memeory + // type and memory type id if the original request type and id can not be + // satisfied, thus the function caller should always check the actual memory + // type and memory type id before use. 
+ AllocatedMemory( + size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id); + + ~AllocatedMemory() override; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/metric_family.cc b/3rdparty/core-r22.12/src/metric_family.cc new file mode 100644 index 0000000000000000000000000000000000000000..4ae3a8174088f42f9cc892ac4997e7bde74e9ee3 --- /dev/null +++ b/3rdparty/core-r22.12/src/metric_family.cc @@ -0,0 +1,321 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifdef TRITON_ENABLE_METRICS + +#include "metric_family.h" +#include "metrics.h" +#include "triton/common/logging.h" + +namespace triton { namespace core { + +// +// Implementation for TRITONSERVER_MetricFamily. 
+// +MetricFamily::MetricFamily( + TRITONSERVER_MetricKind kind, const char* name, const char* description) +{ + auto registry = Metrics::GetRegistry(); + + switch (kind) { + case TRITONSERVER_METRIC_KIND_COUNTER: + family_ = reinterpret_cast(&prometheus::BuildCounter() + .Name(name) + .Help(description) + .Register(*registry)); + break; + case TRITONSERVER_METRIC_KIND_GAUGE: + family_ = reinterpret_cast(&prometheus::BuildGauge() + .Name(name) + .Help(description) + .Register(*registry)); + break; + default: + throw std::invalid_argument( + "Unsupported kind passed to MetricFamily constructor."); + } + + kind_ = kind; +} + +void* +MetricFamily::Add(std::map label_map, Metric* metric) +{ + void* prom_metric = nullptr; + switch (kind_) { + case TRITONSERVER_METRIC_KIND_COUNTER: { + auto counter_family_ptr = + reinterpret_cast*>(family_); + auto counter_ptr = &counter_family_ptr->Add(label_map); + prom_metric = reinterpret_cast(counter_ptr); + break; + } + case TRITONSERVER_METRIC_KIND_GAUGE: { + auto gauge_family_ptr = + reinterpret_cast*>(family_); + auto gauge_ptr = &gauge_family_ptr->Add(label_map); + prom_metric = reinterpret_cast(gauge_ptr); + break; + } + default: + throw std::invalid_argument( + "Unsupported family kind passed to Metric constructor."); + } + + std::lock_guard lk(metric_mtx_); + ++prom_metric_ref_cnt_[prom_metric]; + child_metrics_.insert(metric); + return prom_metric; +} + +void +MetricFamily::Remove(void* prom_metric, Metric* metric) +{ + { + // Remove reference to dependent Metric object + std::lock_guard lk(metric_mtx_); + child_metrics_.erase(metric); + } + + if (prom_metric == nullptr) { + return; + } + + { + std::lock_guard lk(metric_mtx_); + const auto it = prom_metric_ref_cnt_.find(prom_metric); + if (it != prom_metric_ref_cnt_.end()) { + --it->second; + if (it->second == 0) { + prom_metric_ref_cnt_.erase(it); + } else { + // Done as it is not the last reference + return; + } + } + } + + switch (kind_) { + case TRITONSERVER_METRIC_KIND_COUNTER: { + auto counter_family_ptr = + reinterpret_cast*>(family_); + auto counter_ptr = reinterpret_cast(prom_metric); + counter_family_ptr->Remove(counter_ptr); + break; + } + case TRITONSERVER_METRIC_KIND_GAUGE: { + auto gauge_family_ptr = + reinterpret_cast*>(family_); + auto gauge_ptr = reinterpret_cast(prom_metric); + gauge_family_ptr->Remove(gauge_ptr); + break; + } + default: + // Invalid kind should be caught in constructor + LOG_ERROR << "Unsupported kind in Metric destructor."; + break; + } +} + +void +MetricFamily::InvalidateReferences() +{ + std::lock_guard lk(metric_mtx_); + for (auto& metric : child_metrics_) { + if (metric != nullptr) { + metric->Invalidate(); + } + } + child_metrics_.clear(); +} + +MetricFamily::~MetricFamily() +{ + if (NumMetrics() > 0) { + LOG_WARNING << "MetricFamily was deleted before its child Metrics, this " + "should not happen. Make sure to delete all child Metrics " + "before deleting their MetricFamily."; + } + InvalidateReferences(); + // DLIS-4072: Support for removing metric families from registry +} + +// +// Implementation for TRITONSERVER_Metric. 
+// +Metric::Metric( + TRITONSERVER_MetricFamily* family, + std::vector labels) +{ + family_ = reinterpret_cast(family); + kind_ = family_->Kind(); + + // Create map of labels from InferenceParameters + std::map label_map; + for (const auto& param : labels) { + if (param->Type() != TRITONSERVER_PARAMETER_STRING) { + throw std::invalid_argument( + "Parameter [" + param->Name() + + "] must have a type of TRITONSERVER_PARAMETER_STRING to be " + "added as a label."); + } + + label_map[param->Name()] = + std::string(reinterpret_cast(param->ValuePointer())); + } + + metric_ = family_->Add(label_map, this); +} + +Metric::~Metric() +{ + if (family_ != nullptr) { + family_->Remove(metric_, this); + } else { + LOG_WARNING << "Corresponding MetricFamily was deleted before this Metric, " + "this should not happen. Make sure to delete a Metric " + "before deleting its MetricFamily."; + } + // Catch lifetime management / invalid reference issues + Invalidate(); +} + +void +Metric::Invalidate() +{ + family_ = nullptr; + metric_ = nullptr; +} + +TRITONSERVER_Error* +Metric::Value(double* value) +{ + if (metric_ == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "Could not get metric value. Metric has been invalidated."); + } + + switch (kind_) { + case TRITONSERVER_METRIC_KIND_COUNTER: { + auto counter_ptr = reinterpret_cast(metric_); + LOG_VERBOSE(1) << "SETTING COUNTER METRIC FROM: " << *value << " to " + << counter_ptr->Value(); + *value = counter_ptr->Value(); + break; + } + case TRITONSERVER_METRIC_KIND_GAUGE: { + auto gauge_ptr = reinterpret_cast(metric_); + LOG_VERBOSE(1) << "SETTING GAUGE METRIC FROM: " << *value << " to " + << gauge_ptr->Value(); + *value = gauge_ptr->Value(); + break; + } + default: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "Unsupported TRITONSERVER_MetricKind"); + } + + return nullptr; // Success +} + +TRITONSERVER_Error* +Metric::Increment(double value) +{ + if (metric_ == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "Could not increment metric value. Metric has been invalidated."); + } + + switch (kind_) { + case TRITONSERVER_METRIC_KIND_COUNTER: { + if (value < 0.0) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "TRITONSERVER_METRIC_KIND_COUNTER can only be incremented " + "monotonically by non-negative values."); + } + + auto counter_ptr = reinterpret_cast(metric_); + counter_ptr->Increment(value); + break; + } + case TRITONSERVER_METRIC_KIND_GAUGE: { + auto gauge_ptr = reinterpret_cast(metric_); + // Gauge::Increment works for both positive and negative values as of + // prometheus-cpp v1.0 but for now on v0.7 we defer call to + // Increment/Decrement based on the sign of value + // https://github.com/jupp0r/prometheus-cpp/blob/master/core/src/gauge.cc + if (value < 0.0) { + gauge_ptr->Decrement(-1.0 * value); + } else { + gauge_ptr->Increment(value); + } + break; + } + default: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "Unsupported TRITONSERVER_MetricKind"); + } + + return nullptr; // Success +} + +TRITONSERVER_Error* +Metric::Set(double value) +{ + if (metric_ == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "Could not set metric value. 
Metric has been invalidated."); + } + + switch (kind_) { + case TRITONSERVER_METRIC_KIND_COUNTER: { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "TRITONSERVER_METRIC_KIND_COUNTER does not support Set"); + } + case TRITONSERVER_METRIC_KIND_GAUGE: { + auto gauge_ptr = reinterpret_cast(metric_); + gauge_ptr->Set(value); + break; + } + default: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "Unsupported TRITONSERVER_MetricKind"); + } + + return nullptr; // Success +} + +}} // namespace triton::core + +#endif // TRITON_ENABLE_METRICS diff --git a/3rdparty/core-r22.12/src/metric_family.h b/3rdparty/core-r22.12/src/metric_family.h new file mode 100644 index 0000000000000000000000000000000000000000..b5d09d864cf30a90455357b9f42ddf3b11835ad0 --- /dev/null +++ b/3rdparty/core-r22.12/src/metric_family.h @@ -0,0 +1,111 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#ifdef TRITON_ENABLE_METRICS + +#include +#include +#include + +#include "infer_parameter.h" +#include "prometheus/registry.h" +#include "tritonserver_apis.h" + +namespace triton { namespace core { + +// +// Implementation for TRITONSERVER_MetricFamily. +// +class Metric; +class MetricFamily { + public: + MetricFamily( + TRITONSERVER_MetricKind kind, const char* name, const char* description); + ~MetricFamily(); + + void* Family() const { return family_; } + TRITONSERVER_MetricKind Kind() const { return kind_; } + + void* Add(std::map label_map, Metric* metric); + void Remove(void* prom_metric, Metric* metric); + + int NumMetrics() + { + std::lock_guard lk(metric_mtx_); + return child_metrics_.size(); + } + + private: + // If a MetricFamily is deleted before its dependent Metric, we want to + // invalidate the reference so we don't access invalid memory. 
+ void InvalidateReferences(); + + void* family_; + TRITONSERVER_MetricKind kind_; + // Synchronize access of related metric objects + std::mutex metric_mtx_; + // Prometheus returns the existing metric pointer if the metric with the same + // set of labels are requested, as a result, different Metric objects may + // refer to the same prometheus metric. So we must track the reference count + // of the metric and request prometheus to remove it only when all references + // are released. + std::unordered_map prom_metric_ref_cnt_; + // Maintain references to metrics created from this metric family to + // invalidate their references if a family is deleted before its metric + std::set child_metrics_; +}; + +// +// Implementation for TRITONSERVER_Metric. +// +class Metric { + public: + Metric( + TRITONSERVER_MetricFamily* family, + std::vector labels); + ~Metric(); + + MetricFamily* Family() const { return family_; } + TRITONSERVER_MetricKind Kind() const { return kind_; } + + TRITONSERVER_Error* Value(double* value); + TRITONSERVER_Error* Increment(double value); + TRITONSERVER_Error* Set(double value); + + // If a MetricFamily is deleted before its dependent Metric, we want to + // invalidate the references so we don't access invalid memory. + void Invalidate(); + + private: + void* metric_; + MetricFamily* family_; + TRITONSERVER_MetricKind kind_; +}; + +}} // namespace triton::core + +#endif // TRITON_ENABLE_METRICS diff --git a/3rdparty/core-r22.12/src/metric_model_reporter.cc b/3rdparty/core-r22.12/src/metric_model_reporter.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f0905ef7331b23bdba95009f2dea16595c7077c --- /dev/null +++ b/3rdparty/core-r22.12/src/metric_model_reporter.cc @@ -0,0 +1,168 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
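One subtlety in metric_family.h above is the reference counting: as the comment there notes, prometheus-cpp hands back the same underlying metric object for an identical label set, so MetricFamily::Remove() only asks prometheus to drop the metric once the last reference is released. A minimal sketch of that behavior, assuming a build with TRITON_ENABLE_METRICS; the family name and label values are made up, and passing nullptr as the Metric back-pointer is only for illustration (real callers pass the owning Metric object, as Metric's constructor does).

#include <map>
#include <string>

#include "metric_family.h"

// Sketch only: two Add() calls with identical labels share one prometheus
// counter, and the counter is removed only after both references are gone.
void
SharedLabelExample()
{
  triton::core::MetricFamily family(
      TRITONSERVER_METRIC_KIND_COUNTER, "example_total",
      "Illustrative counter family (not a real Triton metric)");

  const std::map<std::string, std::string> labels{{"model", "example"}};

  void* first = family.Add(labels, /*metric=*/nullptr);
  void* second = family.Add(labels, /*metric=*/nullptr);
  // 'first' and 'second' point at the same prometheus::Counter, and the
  // family's internal reference count for it is now 2.

  family.Remove(first, /*metric=*/nullptr);   // count drops to 1, metric kept
  family.Remove(second, /*metric=*/nullptr);  // count hits 0, metric removed
}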
+ +#include "metric_model_reporter.h" + +#ifdef TRITON_ENABLE_METRICS + +#include "constants.h" +#include "metrics.h" + +namespace triton { namespace core { + +Status +MetricModelReporter::Create( + const std::string& model_name, const int64_t model_version, + const int device, const triton::common::MetricTagsMap& model_tags, + std::shared_ptr* metric_model_reporter) +{ + static std::mutex mtx; + static std::unordered_map> + reporter_map; + + std::map labels; + GetMetricLabels(&labels, model_name, model_version, device, model_tags); + auto hash_labels = Metrics::HashLabels(labels); + + std::lock_guard lock(mtx); + + const auto& itr = reporter_map.find(hash_labels); + if (itr != reporter_map.end()) { + // Found in map. If the weak_ptr is still valid that means that + // there are other models using the reporter and we just reuse that + // same reporter. If the weak_ptr is not valid then we need to remove + // the weak_ptr from the map and create the reporter again. + *metric_model_reporter = itr->second.lock(); + if (*metric_model_reporter != nullptr) { + return Status::Success; + } + + reporter_map.erase(itr); + } + + metric_model_reporter->reset( + new MetricModelReporter(model_name, model_version, device, model_tags)); + reporter_map.insert({hash_labels, *metric_model_reporter}); + return Status::Success; +} + +MetricModelReporter::MetricModelReporter( + const std::string& model_name, const int64_t model_version, + const int device, const triton::common::MetricTagsMap& model_tags) +{ + std::map labels; + GetMetricLabels(&labels, model_name, model_version, device, model_tags); + + metric_inf_success_ = + CreateCounterMetric(Metrics::FamilyInferenceSuccess(), labels); + metric_inf_failure_ = + CreateCounterMetric(Metrics::FamilyInferenceFailure(), labels); + metric_inf_count_ = + CreateCounterMetric(Metrics::FamilyInferenceCount(), labels); + metric_inf_exec_count_ = + CreateCounterMetric(Metrics::FamilyInferenceExecutionCount(), labels); + metric_inf_request_duration_us_ = + CreateCounterMetric(Metrics::FamilyInferenceRequestDuration(), labels); + metric_inf_queue_duration_us_ = + CreateCounterMetric(Metrics::FamilyInferenceQueueDuration(), labels); + metric_inf_compute_input_duration_us_ = CreateCounterMetric( + Metrics::FamilyInferenceComputeInputDuration(), labels); + metric_inf_compute_infer_duration_us_ = CreateCounterMetric( + Metrics::FamilyInferenceComputeInferDuration(), labels); + metric_inf_compute_output_duration_us_ = CreateCounterMetric( + Metrics::FamilyInferenceComputeOutputDuration(), labels); + metric_cache_hit_count_ = + CreateCounterMetric(Metrics::FamilyCacheHitCount(), labels); + metric_cache_hit_lookup_duration_us_ = + CreateCounterMetric(Metrics::FamilyCacheHitLookupDuration(), labels); + metric_cache_miss_count_ = + CreateCounterMetric(Metrics::FamilyCacheMissCount(), labels); + metric_cache_miss_lookup_duration_us_ = + CreateCounterMetric(Metrics::FamilyCacheMissLookupDuration(), labels); + metric_cache_miss_insertion_duration_us_ = + CreateCounterMetric(Metrics::FamilyCacheMissInsertionDuration(), labels); +} + +MetricModelReporter::~MetricModelReporter() +{ + Metrics::FamilyInferenceSuccess().Remove(metric_inf_success_); + Metrics::FamilyInferenceFailure().Remove(metric_inf_failure_); + Metrics::FamilyInferenceCount().Remove(metric_inf_count_); + Metrics::FamilyInferenceExecutionCount().Remove(metric_inf_exec_count_); + Metrics::FamilyInferenceRequestDuration().Remove( + metric_inf_request_duration_us_); + 
Metrics::FamilyInferenceQueueDuration().Remove(metric_inf_queue_duration_us_); + Metrics::FamilyInferenceComputeInputDuration().Remove( + metric_inf_compute_input_duration_us_); + Metrics::FamilyInferenceComputeInferDuration().Remove( + metric_inf_compute_infer_duration_us_); + Metrics::FamilyInferenceComputeOutputDuration().Remove( + metric_inf_compute_output_duration_us_); + Metrics::FamilyCacheHitCount().Remove(metric_cache_hit_count_); + Metrics::FamilyCacheHitLookupDuration().Remove( + metric_cache_hit_lookup_duration_us_); + Metrics::FamilyCacheMissCount().Remove(metric_cache_miss_count_); + Metrics::FamilyCacheMissInsertionDuration().Remove( + metric_cache_miss_insertion_duration_us_); +} + +void +MetricModelReporter::GetMetricLabels( + std::map* labels, const std::string& model_name, + const int64_t model_version, const int device, + const triton::common::MetricTagsMap& model_tags) +{ + labels->insert(std::map::value_type( + std::string(kMetricsLabelModelName), model_name)); + labels->insert(std::map::value_type( + std::string(kMetricsLabelModelVersion), std::to_string(model_version))); + for (const auto& tag : model_tags) { + labels->insert(std::map::value_type( + "_" + tag.first, tag.second)); + } + + // 'device' can be < 0 to indicate that the GPU is not known. In + // that case use a metric that doesn't have the gpu_uuid label. + if (device >= 0) { + std::string uuid; + if (Metrics::UUIDForCudaDevice(device, &uuid)) { + labels->insert(std::map::value_type( + std::string(kMetricsLabelGpuUuid), uuid)); + } + } +} + +prometheus::Counter* +MetricModelReporter::CreateCounterMetric( + prometheus::Family& family, + const std::map& labels) +{ + return &family.Add(labels); +} + +}} // namespace triton::core + +#endif // TRITON_ENABLE_METRICS diff --git a/3rdparty/core-r22.12/src/metric_model_reporter.h b/3rdparty/core-r22.12/src/metric_model_reporter.h new file mode 100644 index 0000000000000000000000000000000000000000..282152828fda4dc1df9932d42030dea93941120f --- /dev/null +++ b/3rdparty/core-r22.12/src/metric_model_reporter.h @@ -0,0 +1,138 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include "status.h" +#include "triton/common/model_config.h" + +#ifdef TRITON_ENABLE_METRICS +#include "prometheus/registry.h" +#endif // TRITON_ENABLE_METRICS + +namespace triton { namespace core { + +// +// Interface for a metric reporter for a given version of a model. +// +class MetricModelReporter { + public: +#ifdef TRITON_ENABLE_METRICS + static Status Create( + const std::string& model_name, const int64_t model_version, + const int device, const triton::common::MetricTagsMap& model_tags, + std::shared_ptr* metric_model_reporter); + + ~MetricModelReporter(); + + // Get a metric for the given model, version and GPU index. + prometheus::Counter& MetricInferenceSuccess() const + { + return *metric_inf_success_; + } + prometheus::Counter& MetricInferenceFailure() const + { + return *metric_inf_failure_; + } + prometheus::Counter& MetricInferenceCount() const + { + return *metric_inf_count_; + } + prometheus::Counter& MetricInferenceExecutionCount() const + { + return *metric_inf_exec_count_; + } + prometheus::Counter& MetricInferenceRequestDuration() const + { + return *metric_inf_request_duration_us_; + } + prometheus::Counter& MetricInferenceQueueDuration() const + { + return *metric_inf_queue_duration_us_; + } + prometheus::Counter& MetricInferenceComputeInputDuration() const + { + return *metric_inf_compute_input_duration_us_; + } + prometheus::Counter& MetricInferenceComputeInferDuration() const + { + return *metric_inf_compute_infer_duration_us_; + } + prometheus::Counter& MetricInferenceComputeOutputDuration() const + { + return *metric_inf_compute_output_duration_us_; + } + prometheus::Counter& MetricCacheHitCount() const + { + return *metric_cache_hit_count_; + } + prometheus::Counter& MetricCacheHitLookupDuration() const + { + return *metric_cache_hit_lookup_duration_us_; + } + prometheus::Counter& MetricCacheMissCount() const + { + return *metric_cache_miss_count_; + } + prometheus::Counter& MetricCacheMissLookupDuration() const + { + return *metric_cache_miss_lookup_duration_us_; + } + prometheus::Counter& MetricCacheMissInsertionDuration() const + { + return *metric_cache_miss_insertion_duration_us_; + } + + private: + MetricModelReporter( + const std::string& model_name, const int64_t model_version, + const int device, const triton::common::MetricTagsMap& model_tags); + + static void GetMetricLabels( + std::map* labels, const std::string& model_name, + const int64_t model_version, const int device, + const triton::common::MetricTagsMap& model_tags); + prometheus::Counter* CreateCounterMetric( + prometheus::Family& family, + const std::map& labels); + + prometheus::Counter* metric_inf_success_; + prometheus::Counter* metric_inf_failure_; + prometheus::Counter* metric_inf_count_; + prometheus::Counter* metric_inf_exec_count_; + prometheus::Counter* metric_inf_request_duration_us_; + prometheus::Counter* metric_inf_queue_duration_us_; + prometheus::Counter* metric_inf_compute_input_duration_us_; + prometheus::Counter* 
metric_inf_compute_infer_duration_us_; + prometheus::Counter* metric_inf_compute_output_duration_us_; + prometheus::Counter* metric_cache_hit_count_; + prometheus::Counter* metric_cache_hit_lookup_duration_us_; + prometheus::Counter* metric_cache_miss_count_; + prometheus::Counter* metric_cache_miss_lookup_duration_us_; + prometheus::Counter* metric_cache_miss_insertion_duration_us_; +#endif // TRITON_ENABLE_METRICS +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/metrics.cc b/3rdparty/core-r22.12/src/metrics.cc new file mode 100644 index 0000000000000000000000000000000000000000..0be4907e9b7b000ba54b677133a9e8b3e6814dcf --- /dev/null +++ b/3rdparty/core-r22.12/src/metrics.cc @@ -0,0 +1,1035 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
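The header above exposes each per-model time series as a `prometheus::Counter` accessor. A caller that obtained a reporter through `MetricModelReporter::Create()` could record a completed request roughly as in the sketch below; the helper name and duration variables are illustrative only and do not come from this source.

```cpp
#include <cstdint>
#include <memory>
#include "metric_model_reporter.h"

// Hedged usage sketch: bump the per-model success and duration counters
// for one finished request (names here are illustrative).
void
RecordSuccess(
    const std::shared_ptr<triton::core::MetricModelReporter>& reporter,
    uint64_t request_duration_us, uint64_t queue_duration_us)
{
#ifdef TRITON_ENABLE_METRICS
  reporter->MetricInferenceSuccess().Increment(1);
  reporter->MetricInferenceCount().Increment(1);
  reporter->MetricInferenceRequestDuration().Increment(request_duration_us);
  reporter->MetricInferenceQueueDuration().Increment(queue_duration_us);
#endif  // TRITON_ENABLE_METRICS
}
```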
+// + +#ifdef TRITON_ENABLE_METRICS + +#include "metrics.h" + +#include +#include "constants.h" +#include "prometheus/detail/utils.h" +#include "triton/common/logging.h" + +#ifdef TRITON_ENABLE_METRICS_GPU +#include +#include +#include +#include +#include +#endif // TRITON_ENABLE_METRICS_GPU + +namespace triton { namespace core { + +Metrics::Metrics() + : registry_(std::make_shared()), + serializer_(new prometheus::TextSerializer()), + inf_success_family_( + prometheus::BuildCounter() + .Name("nv_inference_request_success") + .Help("Number of successful inference requests, all batch sizes") + .Register(*registry_)), + inf_failure_family_( + prometheus::BuildCounter() + .Name("nv_inference_request_failure") + .Help("Number of failed inference requests, all batch sizes") + .Register(*registry_)), + inf_count_family_(prometheus::BuildCounter() + .Name("nv_inference_count") + .Help("Number of inferences performed (does not " + "include cached requests)") + .Register(*registry_)), + inf_count_exec_family_(prometheus::BuildCounter() + .Name("nv_inference_exec_count") + .Help("Number of model executions performed " + "(does not include cached requests)") + .Register(*registry_)), + inf_request_duration_us_family_( + prometheus::BuildCounter() + .Name("nv_inference_request_duration_us") + .Help("Cumulative inference request duration in microseconds " + "(includes cached requests)") + .Register(*registry_)), + inf_queue_duration_us_family_( + prometheus::BuildCounter() + .Name("nv_inference_queue_duration_us") + .Help("Cumulative inference queuing duration in microseconds " + "(includes cached requests)") + .Register(*registry_)), + inf_compute_input_duration_us_family_( + prometheus::BuildCounter() + .Name("nv_inference_compute_input_duration_us") + .Help("Cumulative compute input duration in microseconds (does " + "not include cached requests)") + .Register(*registry_)), + inf_compute_infer_duration_us_family_( + prometheus::BuildCounter() + .Name("nv_inference_compute_infer_duration_us") + .Help("Cumulative compute inference duration in microseconds " + "(does not include cached requests)") + .Register(*registry_)), + inf_compute_output_duration_us_family_( + prometheus::BuildCounter() + .Name("nv_inference_compute_output_duration_us") + .Help("Cumulative inference compute output duration in " + "microseconds (does not include cached requests)") + .Register(*registry_)), + cache_num_entries_family_( + prometheus::BuildGauge() + .Name("nv_cache_num_entries") + .Help("Number of responses stored in response cache") + .Register(*registry_)), + cache_num_lookups_family_( + prometheus::BuildGauge() + .Name("nv_cache_num_lookups") + .Help("Number of cache lookups in response cache") + .Register(*registry_)), + cache_num_hits_family_(prometheus::BuildGauge() + .Name("nv_cache_num_hits") + .Help("Number of cache hits in response cache") + .Register(*registry_)), + cache_num_misses_family_( + prometheus::BuildGauge() + .Name("nv_cache_num_misses") + .Help("Number of cache misses in response cache") + .Register(*registry_)), + cache_num_evictions_family_( + prometheus::BuildGauge() + .Name("nv_cache_num_evictions") + .Help("Number of cache evictions in response cache") + .Register(*registry_)), + cache_lookup_duration_us_family_( + prometheus::BuildGauge() + .Name("nv_cache_lookup_duration") + .Help( + "Total cache lookup duration (hit and miss), in microseconds") + .Register(*registry_)), + cache_insertion_duration_us_family_( + prometheus::BuildGauge() + .Name("nv_cache_insertion_duration") + 
.Help("Total cache insertion duration, in microseconds") + .Register(*registry_)), + cache_util_family_(prometheus::BuildGauge() + .Name("nv_cache_util") + .Help("Cache utilization [0.0 - 1.0]") + .Register(*registry_)), + // Per-model cache metric families + cache_num_hits_model_family_(prometheus::BuildCounter() + .Name("nv_cache_num_hits_per_model") + .Help("Number of cache hits per model") + .Register(*registry_)), + cache_hit_lookup_duration_us_model_family_( + prometheus::BuildCounter() + .Name("nv_cache_hit_lookup_duration_per_model") + .Help( + "Total cache hit lookup duration per model, in microseconds") + .Register(*registry_)), + cache_num_misses_model_family_( + prometheus::BuildCounter() + .Name("nv_cache_num_misses_per_model") + .Help("Number of cache misses per model") + .Register(*registry_)), + cache_miss_lookup_duration_us_model_family_( + prometheus::BuildCounter() + .Name("nv_cache_miss_lookup_duration_per_model") + .Help( + "Total cache miss lookup duration per model, in microseconds") + .Register(*registry_)), + cache_miss_insertion_duration_us_model_family_( + prometheus::BuildCounter() + .Name("nv_cache_miss_insertion_duration_per_model") + .Help("Total cache miss insertion duration per model, in " + "microseconds") + .Register(*registry_)), + +#ifdef TRITON_ENABLE_METRICS_GPU + gpu_utilization_family_(prometheus::BuildGauge() + .Name("nv_gpu_utilization") + .Help("GPU utilization rate [0.0 - 1.0)") + .Register(*registry_)), + gpu_memory_total_family_(prometheus::BuildGauge() + .Name("nv_gpu_memory_total_bytes") + .Help("GPU total memory, in bytes") + .Register(*registry_)), + gpu_memory_used_family_(prometheus::BuildGauge() + .Name("nv_gpu_memory_used_bytes") + .Help("GPU used memory, in bytes") + .Register(*registry_)), + gpu_power_usage_family_(prometheus::BuildGauge() + .Name("nv_gpu_power_usage") + .Help("GPU power usage in watts") + .Register(*registry_)), + gpu_power_limit_family_(prometheus::BuildGauge() + .Name("nv_gpu_power_limit") + .Help("GPU power management limit in watts") + .Register(*registry_)), + gpu_energy_consumption_family_( + prometheus::BuildCounter() + .Name("nv_energy_consumption") + .Help("GPU energy consumption in joules since the Triton Server " + "started") + .Register(*registry_)), +#endif // TRITON_ENABLE_METRICS_GPU + +#ifdef TRITON_ENABLE_METRICS_CPU + cpu_utilization_family_(prometheus::BuildGauge() + .Name("nv_cpu_utilization") + .Help("CPU utilization rate [0.0 - 1.0]") + .Register(*registry_)), + cpu_memory_total_family_(prometheus::BuildGauge() + .Name("nv_cpu_memory_total_bytes") + .Help("CPU total memory (RAM), in bytes") + .Register(*registry_)), + cpu_memory_used_family_(prometheus::BuildGauge() + .Name("nv_cpu_memory_used_bytes") + .Help("CPU used memory (RAM), in bytes") + .Register(*registry_)), +#endif // TRITON_ENABLE_METRICS_CPU + + metrics_enabled_(false), gpu_metrics_enabled_(false), + cpu_metrics_enabled_(false), cache_metrics_enabled_(false), + metrics_interval_ms_(2000) +{ +} + +static prometheus::detail::LabelHasher label_hasher_; + +size_t +Metrics::HashLabels(const std::map& labels) +{ + return label_hasher_(labels); +} + +Metrics::~Metrics() +{ + // Signal the cache thread to exit and then wait for it... 
+ if (poll_thread_ != nullptr) { + poll_thread_exit_.store(true); + poll_thread_->join(); +#ifdef TRITON_ENABLE_METRICS_GPU + if (dcgm_metadata_.dcgm_initialized_) { + dcgmReturn_t derr; + // Group destroy will return an error if groupId invalid or dcgm not + // initialized or configured correctly + derr = dcgmGroupDestroy( + dcgm_metadata_.dcgm_handle_, dcgm_metadata_.groupId_); + if (derr != DCGM_ST_OK) { + LOG_WARNING << "Unable to destroy DCGM group: " << errorString(derr); + } + + // Stop and shutdown DCGM + if (dcgm_metadata_.standalone_) { + derr = dcgmDisconnect(dcgm_metadata_.dcgm_handle_); + } else { + derr = dcgmStopEmbedded(dcgm_metadata_.dcgm_handle_); + } + if (derr != DCGM_ST_OK) { + LOG_WARNING << "Unable to stop DCGM: " << errorString(derr); + } + derr = dcgmShutdown(); + if (derr != DCGM_ST_OK) { + LOG_WARNING << "Unable to shutdown DCGM: " << errorString(derr); + } + } +#endif // TRITON_ENABLE_METRICS_GPU + } +} + +bool +Metrics::Enabled() +{ + auto singleton = GetSingleton(); + return singleton->metrics_enabled_; +} + +void +Metrics::EnableMetrics() +{ + auto singleton = GetSingleton(); + singleton->metrics_enabled_ = true; +} + +void +Metrics::EnableCacheMetrics( + std::shared_ptr response_cache) +{ + auto singleton = GetSingleton(); + // Ensure thread-safe enabling of Cache Metrics + std::lock_guard lock(singleton->metrics_enabling_); + if (singleton->cache_metrics_enabled_) { + return; + } + + singleton->InitializeCacheMetrics(response_cache); + singleton->cache_metrics_enabled_ = true; +} + +void +Metrics::EnableGPUMetrics() +{ + auto singleton = GetSingleton(); + // Ensure thread-safe enabling of GPU Metrics + std::lock_guard lock(singleton->metrics_enabling_); + if (singleton->gpu_metrics_enabled_) { + return; + } + + if (std::getenv("TRITON_SERVER_CPU_ONLY") == nullptr) { + singleton->InitializeDcgmMetrics(); + } + + singleton->gpu_metrics_enabled_ = true; +} + +void +Metrics::EnableCpuMetrics() +{ + auto singleton = GetSingleton(); + // Ensure thread-safe enabling of CPU Metrics + std::lock_guard lock(singleton->metrics_enabling_); + if (singleton->cpu_metrics_enabled_) { + return; + } + + singleton->InitializeCpuMetrics(); + singleton->cpu_metrics_enabled_ = true; +} + +void +Metrics::SetMetricsInterval(uint64_t metrics_interval_ms) +{ + auto singleton = GetSingleton(); + singleton->metrics_interval_ms_ = metrics_interval_ms; +} + +void +Metrics::StartPollingThreadSingleton( + std::shared_ptr response_cache) +{ + auto singleton = GetSingleton(); + + // Ensure thread-safe start of polling thread + std::lock_guard lock(singleton->poll_thread_starting_); + if (singleton->poll_thread_started_) { + return; + } + + // Start thread for polling cache/dcgm metrics + singleton->StartPollingThread(response_cache); + + // Toggle flag so this function is only executed once + singleton->poll_thread_started_ = true; +} + +bool +Metrics::StartPollingThread( + std::shared_ptr response_cache) +{ + // Nothing to poll if no polling metrics enabled, don't spawn a thread + if (!cache_metrics_enabled_ && !gpu_metrics_enabled_ && + !cpu_metrics_enabled_) { + LOG_WARNING << "No polling metrics (CPU, GPU, Cache) are enabled. 
Will not " + "poll for them."; + return false; + } + poll_thread_exit_.store(false); + + // Start a separate thread for polling metrics at specified interval + poll_thread_.reset(new std::thread([this, response_cache] { + // Thread will update metrics indefinitely until exit flag set + while (!poll_thread_exit_.load()) { + // Sleep for metric interval + std::this_thread::sleep_for( + std::chrono::milliseconds(metrics_interval_ms_ / 2)); + + // Poll Response Cache metrics + if (cache_metrics_enabled_ && response_cache != nullptr) { + PollCacheMetrics(response_cache); + } + +#ifdef TRITON_ENABLE_METRICS_GPU + // Poll DCGM GPU metrics + if (gpu_metrics_enabled_ && + dcgm_metadata_.available_cuda_gpu_ids_.size() > 0) { + PollDcgmMetrics(); + } +#endif // TRITON_ENABLE_METRICS_GPU + +#ifdef TRITON_ENABLE_METRICS_CPU + if (cpu_metrics_enabled_) { + PollCpuMetrics(); + } +#endif // TRITON_ENABLE_METRICS_CPU + } + })); + + return true; +} + +bool +Metrics::PollCacheMetrics(std::shared_ptr response_cache) +{ + if (response_cache == nullptr) { + LOG_WARNING << "error polling cache metrics, cache metrics will not be " + << "available: cache was nullptr"; + return false; + } + + // Update global cache metrics + cache_num_entries_global_->Set(response_cache->NumEntries()); + cache_num_lookups_global_->Set(response_cache->NumLookups()); + cache_num_hits_global_->Set(response_cache->NumHits()); + cache_num_misses_global_->Set(response_cache->NumMisses()); + cache_num_evictions_global_->Set(response_cache->NumEvictions()); + cache_lookup_duration_us_global_->Set( + response_cache->TotalLookupLatencyNs() / 1000); + cache_insertion_duration_us_global_->Set( + response_cache->TotalInsertionLatencyNs() / 1000); + cache_util_global_->Set(response_cache->TotalUtilization()); + return true; +} + +#ifdef TRITON_ENABLE_METRICS_CPU +Status +Metrics::ParseCpuInfo(CpuInfo& info) +{ +#ifdef _WIN32 + return Status( + Status::Code::INTERNAL, "CPU metrics not supported on Windows."); +#else + std::ifstream ifs("/proc/stat"); + if (!ifs.good()) { + return Status(Status::Code::INTERNAL, "Failed to open /proc/stat."); + } + + std::string line; + // Verify first line is aggregate cpu line + std::getline(ifs, line); + if (line.rfind("cpu ", 0) == std::string::npos) { + return Status( + Status::Code::INTERNAL, + "Failed to find aggregate CPU info in /proc/stat."); + } + + std::string _; + std::istringstream iss(line); + // Use _ to skip "cpu" at start of line + if (!(iss >> _ >> info)) { + return Status( + Status::Code::INTERNAL, + "Failed to parse aggregate CPU info in /proc/stat."); + } + return Status::Success; +#endif // OS +} + +Status +Metrics::ParseMemInfo(MemInfo& info) +{ +#ifdef _WIN32 + return Status( + Status::Code::INTERNAL, "Memory metrics not supported on Windows."); +#else + std::ifstream ifs("/proc/meminfo"); + if (!ifs.good()) { + return Status(Status::Code::INTERNAL, "Failed to open /proc/meminfo."); + } + + std::string line; + constexpr uint64_t KB = 1024; + while (std::getline(ifs, line)) { + std::istringstream iss(line); + std::string name; + uint64_t value = 0; + if (iss >> name >> value) { + name.pop_back(); + info[name] = value * KB; + } else { + return Status( + Status::Code::INTERNAL, "Encountered error parsing /proc/meminfo."); + } + } + + if (info.find("MemTotal") == info.end() || + info.find("MemAvailable") == info.end()) { + return Status( + Status::Code::INTERNAL, + "Failed to find desired values in /proc/meminfo."); + } + + if (info["MemAvailable"] > info["MemTotal"]) { + return Status( + 
Status::Code::INTERNAL, + "Available bytes shouldn't be greater than Total bytes"); + } + + // "Used" memory can be defined in many different ways. While many + // older applications consider "used = total - (free + cached)", a more + // accurate measure of available memory "MemAvailable" was added, + // so we choose "used = total - available" for a more accurate measure. + // This may change in the future if not sufficient for most use cases. + // See https://stackoverflow.com/a/35019697. + info["MemUsed"] = info["MemTotal"] - info["MemAvailable"]; + + return Status::Success; +#endif // OS +} + +double +Metrics::CpuUtilization(const CpuInfo& info_new, const CpuInfo& info_old) +{ + // Account for overflow + const auto wrap_sub = [](uint64_t a, uint64_t b) { + return (a > b) ? (a - b) : 0; + }; + uint64_t util_diff = wrap_sub(info_new.user, info_old.user) + + wrap_sub(info_new.nice, info_old.nice) + + wrap_sub(info_new.system, info_old.system) + + wrap_sub(info_new.irq, info_old.irq) + + wrap_sub(info_new.softirq, info_old.softirq) + + wrap_sub(info_new.steal, info_old.steal); + uint64_t idle_diff = wrap_sub(info_new.idle, info_old.idle) + + wrap_sub(info_new.iowait, info_old.iowait); + double util_ratio = static_cast(util_diff) / (util_diff + idle_diff); + return util_ratio; +} +#endif // TRITON_ENABLE_METRICS_CPU + +bool +Metrics::PollCpuMetrics() +{ +#ifndef TRITON_ENABLE_METRICS_CPU + return false; +#else + // CPU Utilization + double cpu_util = 0.0; + auto cpu_info = CpuInfo(); + auto status = ParseCpuInfo(cpu_info); + if (status.IsOk()) { + cpu_util = CpuUtilization(cpu_info, last_cpu_info_); + last_cpu_info_ = cpu_info; + } + cpu_utilization_->Set(cpu_util); // [0.0, 1.0] + + // RAM / Memory + double mem_total_bytes = 0.0; + double mem_used_bytes = 0.0; + auto mem_info = MemInfo(); + status = ParseMemInfo(mem_info); + if (status.IsOk()) { + // MemTotal will usually not change over time, but if something + // goes wrong when querying memory, we can reflect that by updating. 
+ mem_total_bytes = mem_info["MemTotal"]; + mem_used_bytes = mem_info["MemUsed"]; + } + + cpu_memory_total_->Set(mem_total_bytes); + cpu_memory_used_->Set(mem_used_bytes); + + return true; +#endif // TRITON_ENABLE_METRICS_CPU +} + +bool +Metrics::PollDcgmMetrics() +{ +#ifndef TRITON_ENABLE_METRICS_GPU + return false; +#else + + if (dcgm_metadata_.available_cuda_gpu_ids_.size() == 0) { + LOG_WARNING << "error polling GPU metrics, GPU metrics will not be " + << "available: no available gpus to poll"; + return false; + } + + dcgmUpdateAllFields(dcgm_metadata_.dcgm_handle_, 1 /* wait for update*/); + for (unsigned int didx = 0; + didx < dcgm_metadata_.available_cuda_gpu_ids_.size(); ++didx) { + uint32_t cuda_id = dcgm_metadata_.available_cuda_gpu_ids_[didx]; + if (dcgm_metadata_.cuda_ids_to_dcgm_ids_.count(cuda_id) <= 0) { + LOG_WARNING << "Cannot find DCGM id for CUDA id " << cuda_id; + continue; + } + uint32_t dcgm_id = dcgm_metadata_.cuda_ids_to_dcgm_ids_.at(cuda_id); + dcgmFieldValue_v1 field_values[dcgm_metadata_.field_count_]; + dcgmReturn_t dcgmerr = dcgmGetLatestValuesForFields( + dcgm_metadata_.dcgm_handle_, dcgm_id, dcgm_metadata_.fields_.data(), + dcgm_metadata_.field_count_, field_values); + + if (dcgmerr != DCGM_ST_OK) { + dcgm_metadata_.power_limit_fail_cnt_[didx]++; + dcgm_metadata_.power_usage_fail_cnt_[didx]++; + dcgm_metadata_.energy_fail_cnt_[didx]++; + dcgm_metadata_.util_fail_cnt_[didx]++; + dcgm_metadata_.mem_fail_cnt_[didx]++; + LOG_WARNING << "Unable to get field values for GPU ID " << cuda_id << ": " + << errorString(dcgmerr); + } else { + // Power limit + if (dcgm_metadata_.power_limit_fail_cnt_[didx] < + dcgm_metadata_.fail_threshold_) { + double power_limit = field_values[0].value.dbl; + if ((field_values[0].status == DCGM_ST_OK) && + (!DCGM_FP64_IS_BLANK(power_limit))) { + dcgm_metadata_.power_limit_fail_cnt_[didx] = 0; + } else { + dcgm_metadata_.power_limit_fail_cnt_[didx]++; + power_limit = 0; + dcgmReturn_t status = dcgmReturn_t(field_values[0].status); + LOG_WARNING << "Unable to get power limit for GPU " << cuda_id + << ". Status:" << errorString(status) + << ", value:" << dcgmValueToErrorMessage(power_limit); + } + gpu_power_limit_[didx]->Set(power_limit); + } + + // Power usage + if (dcgm_metadata_.power_usage_fail_cnt_[didx] < + dcgm_metadata_.fail_threshold_) { + double power_usage = field_values[1].value.dbl; + if ((field_values[1].status == DCGM_ST_OK) && + (!DCGM_FP64_IS_BLANK(power_usage))) { + dcgm_metadata_.power_usage_fail_cnt_[didx] = 0; + } else { + dcgm_metadata_.power_usage_fail_cnt_[didx]++; + power_usage = 0; + dcgmReturn_t status = dcgmReturn_t(field_values[1].status); + LOG_WARNING << "Unable to get power usage for GPU " << cuda_id + << ". 
Status:" << errorString(status) + << ", value:" << dcgmValueToErrorMessage(power_usage); + } + gpu_power_usage_[didx]->Set(power_usage); + } + + // Energy Consumption + if (dcgm_metadata_.energy_fail_cnt_[didx] < + dcgm_metadata_.fail_threshold_) { + int64_t energy = field_values[2].value.i64; + if ((field_values[2].status == DCGM_ST_OK) && + (!DCGM_INT64_IS_BLANK(energy))) { + dcgm_metadata_.energy_fail_cnt_[didx] = 0; + if (dcgm_metadata_.last_energy_[didx] == 0) { + dcgm_metadata_.last_energy_[didx] = energy; + } + gpu_energy_consumption_[didx]->Increment( + (double)(energy - dcgm_metadata_.last_energy_[didx]) * 0.001); + dcgm_metadata_.last_energy_[didx] = energy; + } else { + dcgm_metadata_.energy_fail_cnt_[didx]++; + energy = 0; + dcgmReturn_t status = dcgmReturn_t(field_values[2].status); + LOG_WARNING << "Unable to get energy consumption for " + << "GPU " << cuda_id << ". Status:" << errorString(status) + << ", value:" << dcgmValueToErrorMessage(energy); + } + } + + // Utilization + if (dcgm_metadata_.util_fail_cnt_[didx] < + dcgm_metadata_.fail_threshold_) { + int64_t util = field_values[3].value.i64; + if ((field_values[3].status == DCGM_ST_OK) && + (!DCGM_INT64_IS_BLANK(util))) { + dcgm_metadata_.util_fail_cnt_[didx] = 0; + } else { + dcgm_metadata_.util_fail_cnt_[didx]++; + util = 0; + dcgmReturn_t status = dcgmReturn_t(field_values[3].status); + LOG_WARNING << "Unable to get GPU utilization for GPU " << cuda_id + << ". Status:" << errorString(status) + << ", value:" << dcgmValueToErrorMessage(util); + } + gpu_utilization_[didx]->Set((double)util * 0.01); + } + + // Memory Usage + if (dcgm_metadata_.mem_fail_cnt_[didx] < dcgm_metadata_.fail_threshold_) { + int64_t memory_used = field_values[4].value.i64; + int64_t memory_total = field_values[5].value.i64; + if ((field_values[4].status == DCGM_ST_OK) && + (!DCGM_INT64_IS_BLANK(memory_used)) && + (field_values[5].status == DCGM_ST_OK) && + (!DCGM_INT64_IS_BLANK(memory_total))) { + dcgm_metadata_.mem_fail_cnt_[didx] = 0; + } else { + memory_total = 0; + memory_used = 0; + dcgm_metadata_.mem_fail_cnt_[didx]++; + dcgmReturn_t usageStatus = dcgmReturn_t(field_values[4].status); + dcgmReturn_t memoryTotaltatus = dcgmReturn_t(field_values[5].status); + LOG_WARNING << "Unable to get memory usage for GPU " << cuda_id + << ". Memory usage status:" << errorString(usageStatus) + << ", value:" << dcgmValueToErrorMessage(memory_used) + << ". 
Memory total status:" + << errorString(memoryTotaltatus) + << ", value:" << dcgmValueToErrorMessage(memory_total); + } + gpu_memory_total_[didx]->Set(memory_total * 1024 * 1024); // bytes + gpu_memory_used_[didx]->Set(memory_used * 1024 * 1024); // bytes + } + } + } + return true; +#endif // TRITON_ENABLE_METRICS_GPU +} + +bool +Metrics::InitializeCacheMetrics( + std::shared_ptr response_cache) +{ + if (response_cache == nullptr) { + LOG_WARNING + << "error initializing cache metrics, cache metrics will not be " + << "available: cache was nullptr"; + return false; + } + + const std::map cache_labels; + cache_num_entries_global_ = &cache_num_entries_family_.Add(cache_labels); + cache_num_lookups_global_ = &cache_num_lookups_family_.Add(cache_labels); + cache_num_hits_global_ = &cache_num_hits_family_.Add(cache_labels); + cache_num_misses_global_ = &cache_num_misses_family_.Add(cache_labels); + cache_num_evictions_global_ = &cache_num_evictions_family_.Add(cache_labels); + cache_lookup_duration_us_global_ = + &cache_lookup_duration_us_family_.Add(cache_labels); + cache_insertion_duration_us_global_ = + &cache_insertion_duration_us_family_.Add(cache_labels); + cache_util_global_ = &cache_util_family_.Add(cache_labels); + LOG_INFO << "Collecting Response Cache metrics"; + return true; +} + +bool +Metrics::InitializeCpuMetrics() +{ +#ifndef TRITON_ENABLE_METRICS_CPU + return false; +#else + const std::map cpu_labels; + cpu_utilization_ = &cpu_utilization_family_.Add(cpu_labels); + cpu_memory_total_ = &cpu_memory_total_family_.Add(cpu_labels); + cpu_memory_used_ = &cpu_memory_used_family_.Add(cpu_labels); + + // Get baseline CPU info for future comparisons + last_cpu_info_ = CpuInfo(); + auto status = ParseCpuInfo(last_cpu_info_); + if (!status.IsOk()) { + LOG_WARNING << "error initializing CPU metrics, CPU utilization may not " + "be available: " + << status.Message(); + return false; + } + + // Verify memory metrics can be parsed + auto mem_info = MemInfo(); + status = ParseMemInfo(mem_info); + if (!status.IsOk()) { + LOG_WARNING << "error initializing CPU metrics, CPU memory metrics may not " + "be available: " + << status.Message(); + return false; + } + + LOG_INFO << "Collecting CPU metrics"; + return true; +#endif // TRITON_ENABLE_METRICS_CPU +} + +bool +Metrics::InitializeDcgmMetrics() +{ +#ifndef TRITON_ENABLE_METRICS_GPU + return false; +#else + dcgmReturn_t dcgmerr = dcgmInit(); + if (dcgmerr != DCGM_ST_OK) { + LOG_WARNING << "error initializing DCGM, GPU metrics will not be " + << "available: " << errorString(dcgmerr); + return false; + } + + if (dcgm_metadata_.standalone_) { + char hostIpAddress[16] = {0}; + std::string ipAddress = "127.0.0.1"; + strncpy(hostIpAddress, ipAddress.c_str(), 15); + dcgmerr = dcgmConnect(hostIpAddress, &dcgm_metadata_.dcgm_handle_); + } else { + dcgmerr = dcgmStartEmbedded( + DCGM_OPERATION_MODE_MANUAL, &dcgm_metadata_.dcgm_handle_); + } + if (dcgmerr != DCGM_ST_OK) { + LOG_WARNING << "DCGM unable to start: " << errorString(dcgmerr); + return false; + } else { + // Set this flag to signal DCGM cleanup in destructor + dcgm_metadata_.dcgm_initialized_ = true; + } + + if (dcgm_metadata_.standalone_) { + dcgmerr = dcgmUpdateAllFields(dcgm_metadata_.dcgm_handle_, 1); + if (dcgmerr != DCGM_ST_OK) { + LOG_WARNING << "DCGM unable to update all fields, GPU metrics will " + "not be available: " + << errorString(dcgmerr); + return false; + } + } + + unsigned int dcgm_gpu_ids[DCGM_MAX_NUM_DEVICES]; + int dcgm_gpu_count; + dcgmerr = dcgmGetAllDevices( + 
dcgm_metadata_.dcgm_handle_, dcgm_gpu_ids, &dcgm_gpu_count); + if (dcgmerr != DCGM_ST_OK) { + LOG_WARNING << "DCGM unable to get device info and count, GPU " + "metrics will not be available: " + << errorString(dcgmerr); + return false; + } + + // Get PCI Bus ID to DCGM device Id map. + // Some devices may have problems using DCGM API and + // these devices needs to be ignored. + std::map pci_bus_id_to_dcgm_id; + std::map > + pci_bus_id_to_gpu_labels; + std::map pci_bus_id_to_device_name; + dcgmDeviceAttributes_t gpu_attributes[DCGM_MAX_NUM_DEVICES]; + for (int i = 0; i < dcgm_gpu_count; i++) { + gpu_attributes[i].version = dcgmDeviceAttributes_version; + dcgmerr = dcgmGetDeviceAttributes( + dcgm_metadata_.dcgm_handle_, dcgm_gpu_ids[i], &gpu_attributes[i]); + if (dcgmerr != DCGM_ST_OK) { + LOG_WARNING << "DCGM unable to get device properties for DCGM device " + << dcgm_gpu_ids[i] + << ", GPU metrics will not be available for this device: " + << errorString(dcgmerr); + } else { + std::string pciBusId = gpu_attributes[i].identifiers.pciBusId; + pci_bus_id_to_dcgm_id[pciBusId] = i; + pci_bus_id_to_device_name[pciBusId] = + std::string(gpu_attributes[i].identifiers.deviceName); + std::map gpu_labels; + gpu_labels.insert(std::map::value_type( + kMetricsLabelGpuUuid, + std::string(gpu_attributes[i].identifiers.uuid))); + pci_bus_id_to_gpu_labels[pciBusId] = gpu_labels; + } + } + + + // Get CUDA-visible PCI Bus Ids and get DCGM metrics for each CUDA-visible GPU + int cuda_gpu_count; + cudaError_t cudaerr = cudaGetDeviceCount(&cuda_gpu_count); + if (cudaerr != cudaSuccess) { + LOG_WARNING + << "Cannot get CUDA device count, GPU metrics will not be available"; + return false; + } + for (int i = 0; i < cuda_gpu_count; ++i) { + std::string pci_bus_id = "0000"; // pad 0's for uniformity + char pcibusid_str[64]; + cudaerr = cudaDeviceGetPCIBusId(pcibusid_str, sizeof(pcibusid_str) - 1, i); + if (cudaerr == cudaSuccess) { + pci_bus_id.append(pcibusid_str); + if (pci_bus_id_to_dcgm_id.count(pci_bus_id) <= 0) { + LOG_INFO << "Skipping GPU:" << i + << " since it's not CUDA enabled. 
This should never happen!"; + continue; + } + // Filter out CUDA visible GPUs from GPUs found by DCGM + LOG_INFO << "Collecting metrics for GPU " << i << ": " + << pci_bus_id_to_device_name[pci_bus_id]; + auto& gpu_labels = pci_bus_id_to_gpu_labels[pci_bus_id]; + gpu_utilization_.push_back(&gpu_utilization_family_.Add(gpu_labels)); + gpu_memory_total_.push_back(&gpu_memory_total_family_.Add(gpu_labels)); + gpu_memory_used_.push_back(&gpu_memory_used_family_.Add(gpu_labels)); + gpu_power_usage_.push_back(&gpu_power_usage_family_.Add(gpu_labels)); + gpu_power_limit_.push_back(&gpu_power_limit_family_.Add(gpu_labels)); + gpu_energy_consumption_.push_back( + &gpu_energy_consumption_family_.Add(gpu_labels)); + uint32_t dcgm_id = pci_bus_id_to_dcgm_id[pci_bus_id]; + dcgm_metadata_.cuda_ids_to_dcgm_ids_[i] = dcgm_id; + dcgm_metadata_.available_cuda_gpu_ids_.emplace_back(i); + } else { + LOG_WARNING << "GPU metrics will not be available for device:" << i; + } + } + + // create a gpu group + char groupName[] = "dcgm_group"; + dcgmerr = dcgmGroupCreate( + dcgm_metadata_.dcgm_handle_, DCGM_GROUP_DEFAULT, groupName, + &dcgm_metadata_.groupId_); + if (dcgmerr != DCGM_ST_OK) { + LOG_WARNING << "Cannot make GPU group: " << errorString(dcgmerr); + } + + // Initialize tracking vectors + for (unsigned int didx = 0; + didx < dcgm_metadata_.available_cuda_gpu_ids_.size(); ++didx) { + dcgm_metadata_.power_limit_fail_cnt_.push_back(0); + dcgm_metadata_.power_usage_fail_cnt_.push_back(0); + dcgm_metadata_.energy_fail_cnt_.push_back(0); + dcgm_metadata_.util_fail_cnt_.push_back(0); + dcgm_metadata_.mem_fail_cnt_.push_back(0); + dcgm_metadata_.last_energy_.push_back(0); + } + + // Number of fields for DCGM to use from fields_ below + dcgm_metadata_.field_count_ = 6; + unsigned short util_flag = dcgm_metadata_.standalone_ + ? DCGM_FI_PROF_GR_ENGINE_ACTIVE + : DCGM_FI_DEV_GPU_UTIL; + dcgm_metadata_.fields_ = { + DCGM_FI_DEV_POWER_MGMT_LIMIT, // power limit, watts + DCGM_FI_DEV_POWER_USAGE, // power usage, watts + DCGM_FI_DEV_TOTAL_ENERGY_CONSUMPTION, // Total energy consumption, mJ + util_flag, // util ratio, 1 = 1% + DCGM_FI_DEV_FB_USED, // Frame buffer used, MiB + DCGM_FI_DEV_FB_TOTAL, // Frame buffer used, MiB + }; + + char fieldName[] = "field_group"; + dcgmFieldGrp_t fieldGroupId; + dcgmerr = dcgmFieldGroupCreate( + dcgm_metadata_.dcgm_handle_, dcgm_metadata_.field_count_, + dcgm_metadata_.fields_.data(), fieldName, &fieldGroupId); + if (dcgmerr != DCGM_ST_OK) { + LOG_WARNING << "Cannot make field group: " << errorString(dcgmerr); + } + + dcgmerr = dcgmWatchFields( + dcgm_metadata_.dcgm_handle_, dcgm_metadata_.groupId_, fieldGroupId, + metrics_interval_ms_ * 1000 /*update period, usec*/, + 5.0 /*maxKeepAge, sec*/, 5 /*maxKeepSamples*/); + if (dcgmerr != DCGM_ST_OK) { + LOG_WARNING << "Cannot start watching fields: " << errorString(dcgmerr); + return false; + } + + return true; +#endif // TRITON_ENABLE_METRICS_GPU +} + +#ifdef TRITON_ENABLE_METRICS_GPU +std::string +Metrics::dcgmValueToErrorMessage(double val) +{ + if (DCGM_FP64_IS_BLANK(val)) { + if (val == DCGM_FP64_BLANK) { + return "Not Specified"; + } else if (val == DCGM_FP64_NOT_FOUND) { + return "Not Found"; + } else if (val == DCGM_FP64_NOT_SUPPORTED) { + return "Not Supported"; + } else if (val == DCGM_FP64_NOT_PERMISSIONED) { + return "Insf. 
Permission"; + } else { + return "Unknown"; + } + } else { + return std::to_string(val); + } +} + +std::string +Metrics::dcgmValueToErrorMessage(int64_t val) +{ + if (DCGM_INT64_IS_BLANK(val)) { + switch (val) { + case DCGM_INT64_BLANK: + return "Not Specified"; + case DCGM_INT64_NOT_FOUND: + return "Not Found"; + case DCGM_INT64_NOT_SUPPORTED: + return "Not Supported"; + case DCGM_INT64_NOT_PERMISSIONED: + return "Insf. Permission"; + default: + return "Unknown"; + } + } else { + return std::to_string(val); + } +} +#endif // TRITON_ENABLE_METRICS_GPU + +bool +Metrics::UUIDForCudaDevice(int cuda_device, std::string* uuid) +{ + // If metrics were not initialized then just silently fail since + // with DCGM we can't get the CUDA device (and not worth doing + // anyway since metrics aren't being reported). + auto singleton = GetSingleton(); + if (!singleton->gpu_metrics_enabled_) { + return false; + } + + // If GPU metrics is not enabled just silently fail. +#ifndef TRITON_ENABLE_METRICS_GPU + return false; +#else + + dcgmDeviceAttributes_t gpu_attributes; + gpu_attributes.version = dcgmDeviceAttributes_version; + dcgmReturn_t dcgmerr = dcgmGetDeviceAttributes( + singleton->dcgm_metadata_.dcgm_handle_, cuda_device, &gpu_attributes); + if (dcgmerr != DCGM_ST_OK) { + LOG_ERROR << "Unable to get device UUID: " << errorString(dcgmerr); + return false; + } + + *uuid = gpu_attributes.identifiers.uuid; + return true; +#endif // TRITON_ENABLE_METRICS_GPU +} + +std::shared_ptr +Metrics::GetRegistry() +{ + auto singleton = Metrics::GetSingleton(); + return singleton->registry_; +} + +const std::string +Metrics::SerializedMetrics() +{ + auto singleton = Metrics::GetSingleton(); + return singleton->serializer_->Serialize( + singleton->registry_.get()->Collect()); +} + +Metrics* +Metrics::GetSingleton() +{ + static Metrics singleton; + return &singleton; +} + +}} // namespace triton::core + +#endif // TRITON_ENABLE_METRICS diff --git a/3rdparty/core-r22.12/src/metrics.h b/3rdparty/core-r22.12/src/metrics.h new file mode 100644 index 0000000000000000000000000000000000000000..9b7e8f4a168f3def67cc91d121f39e15178fd6d2 --- /dev/null +++ b/3rdparty/core-r22.12/src/metrics.h @@ -0,0 +1,335 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +#pragma once + +#ifdef TRITON_ENABLE_METRICS + +#include +#include +#include +#include "prometheus/counter.h" +#include "prometheus/gauge.h" +#include "prometheus/registry.h" +#include "prometheus/serializer.h" +#include "prometheus/text_serializer.h" +#include "response_cache.h" + +#ifdef TRITON_ENABLE_METRICS_GPU +#include +#endif // TRITON_ENABLE_METRICS_GPU + +namespace triton { namespace core { + +#ifdef TRITON_ENABLE_METRICS_CPU +using MemInfo = std::unordered_map; + +// References: +// - htop source: https://stackoverflow.com/a/23376195 +// - Linux docs: https://www.kernel.org/doc/Documentation/filesystems/proc.txt +// guest/guestnice values are counted in user/nice so we skip parsing them +struct CpuInfo { + uint64_t user = 0; // normal processes executing in user mode + uint64_t nice = 0; // niced processes executing in user mode + uint64_t system = 0; // processes executing in kernel mode + uint64_t idle = 0; // twiddling thumbs + uint64_t iowait = 0; // waiting for I/O to complete + uint64_t irq = 0; // servicing interrupts + uint64_t softirq = 0; // servicing softirqs + uint64_t steal = 0; // involuntary wait +}; + +inline std::istream& +operator>>(std::istream& is, CpuInfo& info) +{ + is >> info.user >> info.nice >> info.system >> info.idle >> info.iowait >> + info.irq >> info.softirq >> info.steal; + return is; +} +#endif // TRITON_ENABLE_METRICS_CPU + +#ifdef TRITON_ENABLE_METRICS_GPU +struct DcgmMetadata { + // DCGM handles for initialization and destruction + dcgmHandle_t dcgm_handle_ = 0; + dcgmGpuGrp_t groupId_ = 0; + // DCGM Flags + bool standalone_ = false; + // DCGM Fields + size_t field_count_ = 0; + std::vector fields_; + // GPU Device Mapping + std::map cuda_ids_to_dcgm_ids_; + std::vector available_cuda_gpu_ids_; + // Stop attempting metrics if they fail multiple consecutive + // times for a device. + const int fail_threshold_ = 3; + // DCGM Failure Tracking + std::vector power_limit_fail_cnt_; + std::vector power_usage_fail_cnt_; + std::vector energy_fail_cnt_; + std::vector util_fail_cnt_; + std::vector mem_fail_cnt_; + // DCGM Energy Tracking + std::vector last_energy_; + // Track if DCGM handle initialized successfully + bool dcgm_initialized_ = false; +}; +#endif // TRITON_ENABLE_METRICS_GPU + +class Metrics { + public: + // Return the hash value of the labels + static size_t HashLabels(const std::map& labels); + + // Are metrics enabled? 
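The `CpuInfo` extractor defined above consumes the numeric fields of the aggregate `cpu` line of `/proc/stat`, and `Metrics::CpuUtilization()` in metrics.cc derives a busy ratio from the deltas between two samples. A self-contained sketch of that computation with made-up sample values:

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

// Local mirror of the CpuInfo fields used above (sample values below are
// made up, not real /proc/stat readings).
struct Sample {
  uint64_t user = 0, nice = 0, system = 0, idle = 0;
  uint64_t iowait = 0, irq = 0, softirq = 0, steal = 0;
};

std::istream&
operator>>(std::istream& is, Sample& s)
{
  return is >> s.user >> s.nice >> s.system >> s.idle >> s.iowait >> s.irq >>
         s.softirq >> s.steal;
}

int
main()
{
  // Two hypothetical snapshots of the aggregate "cpu" line, old then new.
  std::istringstream old_line("cpu  1000 10 200 8000 50 5 15 0 0 0");
  std::istringstream new_line("cpu  1400 12 260 8600 55 6 17 0 0 0");

  std::string tag;
  Sample prev, curr;
  old_line >> tag >> prev;
  new_line >> tag >> curr;

  const uint64_t busy = (curr.user - prev.user) + (curr.nice - prev.nice) +
                        (curr.system - prev.system) + (curr.irq - prev.irq) +
                        (curr.softirq - prev.softirq) +
                        (curr.steal - prev.steal);
  const uint64_t idle = (curr.idle - prev.idle) + (curr.iowait - prev.iowait);

  // busy = 465, idle = 605, so the ratio is roughly 0.43.
  std::cout << "utilization: " << static_cast<double>(busy) / (busy + idle)
            << "\n";
  return 0;
}
```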
+ static bool Enabled(); + + // Enable reporting of metrics + static void EnableMetrics(); + + // Enable reporting of GPU metrics + static void EnableGPUMetrics(); + + // Enable reporting of CPU metrics + static void EnableCpuMetrics(); + + // Enable reporting of Cache metrics + static void EnableCacheMetrics( + std::shared_ptr response_cache); + + // Start a thread for polling enabled metrics if any + static void StartPollingThreadSingleton( + std::shared_ptr response_cache); + + // Set the time interval in secs at which metrics are collected + static void SetMetricsInterval(uint64_t metrics_interval_ms); + + // Get the prometheus registry + static std::shared_ptr GetRegistry(); + + // Get serialized metrics + static const std::string SerializedMetrics(); + + // Get the UUID for a CUDA device. Return true and initialize 'uuid' + // if a UUID is found, return false if a UUID cannot be returned. + static bool UUIDForCudaDevice(int cuda_device, std::string* uuid); + + // Metric family counting successful inference requests + static prometheus::Family& FamilyInferenceSuccess() + { + return GetSingleton()->inf_success_family_; + } + + // Metric family counting failed inference requests + static prometheus::Family& FamilyInferenceFailure() + { + return GetSingleton()->inf_failure_family_; + } + + // Metric family counting inferences performed, where a batch-size + // 'n' inference request is counted as 'n' inferences + static prometheus::Family& FamilyInferenceCount() + { + return GetSingleton()->inf_count_family_; + } + + // Metric family counting inferences performed, where a batch-size + // 'n' inference request is counted as 'n' inferences + static prometheus::Family& + FamilyInferenceExecutionCount() + { + return GetSingleton()->inf_count_exec_family_; + } + + // Metric family of cumulative inference request duration, in + // microseconds + static prometheus::Family& + FamilyInferenceRequestDuration() + { + return GetSingleton()->inf_request_duration_us_family_; + } + + // Metric family of cumulative inference queuing duration, in + // microseconds + static prometheus::Family& FamilyInferenceQueueDuration() + { + return GetSingleton()->inf_queue_duration_us_family_; + } + + // Metric family of cumulative inference compute durations, in + // microseconds + static prometheus::Family& + FamilyInferenceComputeInputDuration() + { + return GetSingleton()->inf_compute_input_duration_us_family_; + } + static prometheus::Family& + FamilyInferenceComputeInferDuration() + { + return GetSingleton()->inf_compute_infer_duration_us_family_; + } + static prometheus::Family& + FamilyInferenceComputeOutputDuration() + { + return GetSingleton()->inf_compute_output_duration_us_family_; + } + // Metric families of per-model response cache metrics + static prometheus::Family& FamilyCacheHitCount() + { + return GetSingleton()->cache_num_hits_model_family_; + } + static prometheus::Family& FamilyCacheHitLookupDuration() + { + return GetSingleton()->cache_hit_lookup_duration_us_model_family_; + } + static prometheus::Family& FamilyCacheMissCount() + { + return GetSingleton()->cache_num_misses_model_family_; + } + static prometheus::Family& + FamilyCacheMissLookupDuration() + { + return GetSingleton()->cache_miss_lookup_duration_us_model_family_; + } + static prometheus::Family& + FamilyCacheMissInsertionDuration() + { + return GetSingleton()->cache_miss_insertion_duration_us_model_family_; + } + + + private: + Metrics(); + virtual ~Metrics(); + static Metrics* GetSingleton(); + bool InitializeDcgmMetrics(); + bool 
InitializeCpuMetrics(); + bool InitializeCacheMetrics( + std::shared_ptr response_cache); + bool StartPollingThread(std::shared_ptr response_cache); + bool PollCacheMetrics(std::shared_ptr response_cache); + bool PollDcgmMetrics(); + bool PollCpuMetrics(); + + std::string dcgmValueToErrorMessage(double val); + std::string dcgmValueToErrorMessage(int64_t val); + + std::shared_ptr registry_; + std::unique_ptr serializer_; + + prometheus::Family& inf_success_family_; + prometheus::Family& inf_failure_family_; + prometheus::Family& inf_count_family_; + prometheus::Family& inf_count_exec_family_; + prometheus::Family& inf_request_duration_us_family_; + prometheus::Family& inf_queue_duration_us_family_; + prometheus::Family& + inf_compute_input_duration_us_family_; + prometheus::Family& + inf_compute_infer_duration_us_family_; + prometheus::Family& + inf_compute_output_duration_us_family_; + // Global Response Cache metrics + prometheus::Family& cache_num_entries_family_; + prometheus::Family& cache_num_lookups_family_; + prometheus::Family& cache_num_hits_family_; + prometheus::Family& cache_num_misses_family_; + prometheus::Family& cache_num_evictions_family_; + prometheus::Family& cache_lookup_duration_us_family_; + prometheus::Family& cache_insertion_duration_us_family_; + prometheus::Family& cache_util_family_; + // Gauges for Global Response Cache metrics + prometheus::Gauge* cache_num_entries_global_; + prometheus::Gauge* cache_num_lookups_global_; + prometheus::Gauge* cache_num_hits_global_; + prometheus::Gauge* cache_num_misses_global_; + prometheus::Gauge* cache_num_evictions_global_; + prometheus::Gauge* cache_lookup_duration_us_global_; + prometheus::Gauge* cache_insertion_duration_us_global_; + prometheus::Gauge* cache_util_global_; + // Per-model Response Cache metrics + prometheus::Family& cache_num_hits_model_family_; + prometheus::Family& + cache_hit_lookup_duration_us_model_family_; + prometheus::Family& cache_num_misses_model_family_; + prometheus::Family& + cache_miss_lookup_duration_us_model_family_; + prometheus::Family& + cache_miss_insertion_duration_us_model_family_; + +#ifdef TRITON_ENABLE_METRICS_GPU + prometheus::Family& gpu_utilization_family_; + prometheus::Family& gpu_memory_total_family_; + prometheus::Family& gpu_memory_used_family_; + prometheus::Family& gpu_power_usage_family_; + prometheus::Family& gpu_power_limit_family_; + prometheus::Family& gpu_energy_consumption_family_; + + std::vector gpu_utilization_; + std::vector gpu_memory_total_; + std::vector gpu_memory_used_; + std::vector gpu_power_usage_; + std::vector gpu_power_limit_; + std::vector gpu_energy_consumption_; + + DcgmMetadata dcgm_metadata_; +#endif // TRITON_ENABLE_METRICS_GPU + +#ifdef TRITON_ENABLE_METRICS_CPU + // Parses "/proc/meminfo" for metrics, currently only supported on Linux. + Status ParseMemInfo(MemInfo& info); + // Parses "/proc/stat" for metrics, currently only supported on Linux. 
+  Status ParseCpuInfo(CpuInfo& info);
+  // Computes CPU utilization between "info_new" and "info_old" values
+  double CpuUtilization(const CpuInfo& info_new, const CpuInfo& info_old);
+
+  prometheus::Family<prometheus::Gauge>& cpu_utilization_family_;
+  prometheus::Family<prometheus::Gauge>& cpu_memory_total_family_;
+  prometheus::Family<prometheus::Gauge>& cpu_memory_used_family_;
+
+  prometheus::Gauge* cpu_utilization_;
+  prometheus::Gauge* cpu_memory_total_;
+  prometheus::Gauge* cpu_memory_used_;
+  CpuInfo last_cpu_info_;
+#endif  // TRITON_ENABLE_METRICS_CPU
+
+  // Thread for polling cache/gpu metrics periodically
+  std::unique_ptr<std::thread> poll_thread_;
+  std::atomic<bool> poll_thread_exit_;
+  bool metrics_enabled_;
+  bool gpu_metrics_enabled_;
+  bool cpu_metrics_enabled_;
+  bool cache_metrics_enabled_;
+  bool poll_thread_started_;
+  std::mutex metrics_enabling_;
+  std::mutex poll_thread_starting_;
+  uint64_t metrics_interval_ms_;
+};
+
+}}  // namespace triton::core
+
+#endif  // TRITON_ENABLE_METRICS
diff --git a/3rdparty/core-r22.12/src/model.cc b/3rdparty/core-r22.12/src/model.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c59a3170c7319b926fab9dbaa95a3faeda768816
--- /dev/null
+++ b/3rdparty/core-r22.12/src/model.cc
@@ -0,0 +1,137 @@
+// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//   * Redistributions of source code must retain the above copyright
+//     notice, this list of conditions and the following disclaimer.
+//   * Redistributions in binary form must reproduce the above copyright
+//     notice, this list of conditions and the following disclaimer in the
+//     documentation and/or other materials provided with the distribution.
+//   * Neither the name of NVIDIA CORPORATION nor the names of its
+//     contributors may be used to endorse or promote products derived
+//     from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
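The `poll_thread_` and `poll_thread_exit_` members declared in the `Metrics` class above implement a common stop-flag pattern: a background thread sleeps in a loop until an atomic flag tells it to exit, and the owner joins it on destruction, as `StartPollingThread()` and `~Metrics()` do in metrics.cc. Reduced to its essentials, with a stand-in `Poll()` work function:

```cpp
#include <atomic>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <memory>
#include <thread>

class Poller {
 public:
  explicit Poller(uint64_t interval_ms) : exit_(false)
  {
    thread_.reset(new std::thread([this, interval_ms] {
      // Keep polling until the owner requests shutdown.
      while (!exit_.load()) {
        std::this_thread::sleep_for(std::chrono::milliseconds(interval_ms));
        Poll();
      }
    }));
  }

  ~Poller()
  {
    // Signal the thread to exit and then wait for it, as ~Metrics() does.
    exit_.store(true);
    if (thread_ && thread_->joinable()) {
      thread_->join();
    }
  }

 private:
  void Poll() { std::cout << "polling...\n"; }  // stand-in for real work

  std::unique_ptr<std::thread> thread_;
  std::atomic<bool> exit_;
};
```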
+ +#include "model.h" + +#include +#include +#include "constants.h" +#include "filesystem.h" +#include "infer_request.h" +#include "model_config_utils.h" +#include "triton/common/logging.h" + +namespace triton { namespace core { + +Status +Model::GetInput( + const std::string& name, const inference::ModelInput** input) const +{ + const auto itr = input_map_.find(name); + if (itr == input_map_.end()) { + return Status( + Status::Code::INVALID_ARG, + "unexpected inference input '" + name + "' for model '" + Name() + "'"); + } + + *input = &itr->second; + return Status::Success; +} + +Status +Model::GetOutput( + const std::string& name, const inference::ModelOutput** output) const +{ + const auto itr = output_map_.find(name); + if (itr == output_map_.end()) { + return Status( + Status::Code::INVALID_ARG, "unexpected inference output '" + name + + "' for model '" + Name() + "'"); + } + + *output = &itr->second; + return Status::Success; +} + +Status +Model::SetModelConfig(const inference::ModelConfig& config) +{ + config_ = config; + set_model_config_ = true; + + return Status::Success; +} + +Status +Model::SetScheduler(std::unique_ptr scheduler) +{ + if (scheduler_ != nullptr) { + return Status( + Status::Code::INTERNAL, "Attempt to change scheduler not allowed"); + } + + scheduler_ = std::move(scheduler); + return Status::Success; +} + +Status +Model::Init(const bool is_config_provided) +{ + if (!set_model_config_ && !is_config_provided) { + return Status( + Status::Code::NOT_FOUND, + "model configuration is not provided for model '" + Name() + "'"); + } + + RETURN_IF_ERROR(ValidateModelConfig(config_, min_compute_capability_)); + RETURN_IF_ERROR(ValidateModelIOConfig(config_)); + + // Initialize the input map + for (const auto& io : config_.input()) { + input_map_.insert(std::make_pair(io.name(), io)); + if (!io.optional()) { + ++required_input_count_; + } + } + + // Initialize the output map and label provider for each output + label_provider_ = std::make_shared(); + for (const auto& io : config_.output()) { + output_map_.insert(std::make_pair(io.name(), io)); + + if (!io.label_filename().empty()) { + const auto label_path = JoinPath({model_dir_, io.label_filename()}); + RETURN_IF_ERROR(label_provider_->AddLabels(io.name(), label_path)); + } + } + + if (config_.has_dynamic_batching()) { + default_priority_level_ = + config_.dynamic_batching().default_priority_level(); + max_priority_level_ = config_.dynamic_batching().priority_levels(); + } else if (config_.has_ensemble_scheduling()) { + // For ensemble, allow any priority level to pass through + default_priority_level_ = 0; + max_priority_level_ = UINT32_MAX; + } else { + default_priority_level_ = 0; + max_priority_level_ = 0; + } + + return Status::Success; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/model.h b/3rdparty/core-r22.12/src/model.h new file mode 100644 index 0000000000000000000000000000000000000000..240849856bf3bf0e5dc1c64fe42c40b17adc0b61 --- /dev/null +++ b/3rdparty/core-r22.12/src/model.h @@ -0,0 +1,162 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include "infer_stats.h" +#include "label_provider.h" +#include "model_config.pb.h" +#include "scheduler.h" +#include "status.h" + +namespace triton { namespace core { + +class InferenceRequest; + +// +// Interface for models that handle inference requests. +// +class Model { + public: + explicit Model( + const double min_compute_capability, const std::string& model_dir, + const int64_t version, const inference::ModelConfig& config) + : config_(config), min_compute_capability_(min_compute_capability), + version_(version), required_input_count_(0), model_dir_(model_dir), + set_model_config_(false) + { + } + virtual ~Model() {} + + // Get the name of model being served. + const std::string& Name() const { return config_.name(); } + + // Get the version of model being served. + int64_t Version() const { return version_; } + + // Get the configuration of model being served. + const inference::ModelConfig& Config() const { return config_; } + + // Get the number of required inputs + size_t RequiredInputCount() const { return required_input_count_; } + + // Get the stats collector for the model being served. + InferenceStatsAggregator* MutableStatsAggregator() + { + return &stats_aggregator_; + } + const InferenceStatsAggregator& StatsAggregator() const + { + return stats_aggregator_; + } + + // Get the model configuration for a named input. + Status GetInput( + const std::string& name, const inference::ModelInput** input) const; + + // Get the model configuration for a named output. + Status GetOutput( + const std::string& name, const inference::ModelOutput** output) const; + + // Get a label provider for the model. + const std::shared_ptr& GetLabelProvider() const + { + return label_provider_; + } + + // Initialize the instance for Triton core usage + Status Init(const bool is_config_provided); + + // Enqueue a request for execution. If Status::Success is returned + // then the model has taken ownership of the request object and so + // 'request' will be nullptr. If non-success is returned then the + // caller still retains ownership of 'request'. + Status Enqueue(std::unique_ptr& request) + { + return scheduler_->Enqueue(request); + } + + // Return the number of in-flight inferences. 
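The comment on `Enqueue()` above describes a transfer-of-ownership contract: on success the model owns the request and the caller's pointer becomes null, on failure the caller keeps ownership. A hedged caller-side sketch of that contract; the helper name and its error handling are illustrative and not part of this source:

```cpp
#include <memory>
#include "infer_request.h"
#include "model.h"

namespace triton { namespace core {

// Illustrative only: enqueue a request and honor the ownership rules
// documented on Model::Enqueue().
Status
TryEnqueue(Model* model, std::unique_ptr<InferenceRequest>&& request)
{
  Status status = model->Enqueue(request);
  if (status.IsOk()) {
    // The model's scheduler now owns the request; 'request' is nullptr.
    return status;
  }
  // Enqueue failed, so 'request' is still owned here; a real caller might
  // respond with the error or retry before letting it go out of scope.
  return status;
}

}}  // namespace triton::core
```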
+  size_t InflightInferenceCount()
+  {
+    return scheduler_->InflightInferenceCount();
+  }
+
+  // Stop processing future requests unless they are considered as in-flight.
+  void Stop() { scheduler_->Stop(); }
+
+  uint32_t DefaultPriorityLevel() const { return default_priority_level_; }
+
+  uint32_t MaxPriorityLevel() const { return max_priority_level_; }
+
+ protected:
+  // Set the configuration of the model being served.
+  Status SetModelConfig(const inference::ModelConfig& config);
+
+  // Explicitly set the scheduler to use for inference requests to the
+  // model. The scheduler can only be set once for a model.
+  Status SetScheduler(std::unique_ptr<Scheduler> scheduler);
+
+  // The scheduler to use for this model.
+  std::unique_ptr<Scheduler> scheduler_;
+
+  // Configuration of the model.
+  inference::ModelConfig config_;
+
+ private:
+  // The minimum supported CUDA compute capability.
+  const double min_compute_capability_;
+
+  // Version of the model.
+  int64_t version_;
+
+  // The stats collector for the model.
+  InferenceStatsAggregator stats_aggregator_;
+
+  // Label provider for this model.
+  std::shared_ptr<LabelProvider> label_provider_;
+
+  size_t required_input_count_;
+
+  // Map from input name to the model configuration for that input.
+  std::unordered_map<std::string, inference::ModelInput> input_map_;
+
+  // Map from output name to the model configuration for that output.
+  std::unordered_map<std::string, inference::ModelOutput> output_map_;
+
+  // Path to model
+  std::string model_dir_;
+
+  // The default priority level for the model.
+  uint32_t default_priority_level_;
+
+  // The largest priority value for the model.
+  uint32_t max_priority_level_;
+
+  // Whether or not model config has been set.
+  bool set_model_config_;
+};
+
+}}  // namespace triton::core
diff --git a/3rdparty/core-r22.12/src/model_config_cuda.cc b/3rdparty/core-r22.12/src/model_config_cuda.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e08dfb499db1396829ffc9d2f9be0e380757b5ca
--- /dev/null
+++ b/3rdparty/core-r22.12/src/model_config_cuda.cc
@@ -0,0 +1,61 @@
+// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//   * Redistributions of source code must retain the above copyright
+//     notice, this list of conditions and the following disclaimer.
+//   * Redistributions in binary form must reproduce the above copyright
+//     notice, this list of conditions and the following disclaimer in the
+//     documentation and/or other materials provided with the distribution.
+//   * Neither the name of NVIDIA CORPORATION nor the names of its
+//     contributors may be used to endorse or promote products derived
+//     from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "model_config_cuda.h" + +#include + +namespace triton { namespace core { + +int +GetCudaStreamPriority( + inference::ModelOptimizationPolicy::ModelPriority priority) +{ + // Default priority is 0 + int cuda_stream_priority = 0; + + int min, max; + cudaError_t cuerr = cudaDeviceGetStreamPriorityRange(&min, &max); + if ((cuerr != cudaErrorNoDevice) && (cuerr != cudaSuccess)) { + return 0; + } + + switch (priority) { + case inference::ModelOptimizationPolicy::PRIORITY_MAX: + cuda_stream_priority = max; + break; + case inference::ModelOptimizationPolicy::PRIORITY_MIN: + cuda_stream_priority = min; + break; + default: + cuda_stream_priority = 0; + break; + } + + return cuda_stream_priority; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/model_config_cuda.h b/3rdparty/core-r22.12/src/model_config_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..f939232a312bfab372ecfcbf644ae9c496ce8525 --- /dev/null +++ b/3rdparty/core-r22.12/src/model_config_cuda.h @@ -0,0 +1,40 @@ +// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include "model_config.pb.h" + +namespace triton { namespace core { + +/// Get the CUDA stream priority for a given ModelPriority +/// \param priority The inference::ModelOptimizationPolicy::ModelPriority +/// priority. \param cuda_stream_priority Returns the CUDA stream priority. +/// \return The error status. 
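+/// Note: the priority is returned directly as the int return value rather
+/// than through an output parameter; 0 (the default CUDA stream priority)
+/// is used when the priority is neither PRIORITY_MAX nor PRIORITY_MIN, or
+/// when querying the device priority range fails with an error other than
+/// cudaErrorNoDevice.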
+int GetCudaStreamPriority( + inference::ModelOptimizationPolicy::ModelPriority priority); + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/model_config_utils.cc b/3rdparty/core-r22.12/src/model_config_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..189b2df23e34e0a95f4a502acd6fea41fa4d7b22 --- /dev/null +++ b/3rdparty/core-r22.12/src/model_config_utils.cc @@ -0,0 +1,2294 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "model_config_utils.h" + +#include +#include +#include +#include +#include "constants.h" +#include "cuda_utils.h" +#include "filesystem.h" +#include "triton/common/logging.h" + +#define TRITONJSON_STATUSTYPE triton::core::Status +#define TRITONJSON_STATUSRETURN(M) \ + return triton::core::Status(triton::core::Status::Code::INTERNAL, (M)) +#define TRITONJSON_STATUSSUCCESS triton::core::Status::Success +#include "triton/common/triton_json.h" + +#ifdef TRITON_ENABLE_GPU +#include +#endif // TRITON_ENABLE_GPU + +namespace triton { namespace core { + +namespace { + +#ifdef TRITON_ENABLE_ENSEMBLE + +struct EnsembleTensor { + EnsembleTensor(bool isOutput) : ready(false), isOutput(isOutput) {} + bool ready; + bool isOutput; + std::vector prev_nodes; + std::vector next_nodes; +}; + +/// Build a graph that represents the data flow in the ensemble specified in +/// given model config. the node (ensemble tensor) in the graph can be looked +/// up using its name as key. +/// \param ensemble_config The model configuration that specifies +/// ensemble_scheduling field. +/// \param keyed_ensemble_graph Returned the ensemble graph. +/// \return The error status. A non-OK status indicates the build fails because +/// the ensemble configuration is not valid. 
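+/// For example, a (hypothetical) step such as
+///   step [ {
+///     model_name: "preprocess"
+///     input_map { key: "RAW" value: "INPUT" }
+///     output_map { key: "OUT" value: "preprocessed" }
+///   } ]
+/// creates the ensemble-tensor nodes "INPUT" and "preprocessed", marks
+/// "preprocessed" as produced by a step, and links it as a successor of
+/// "INPUT".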
+Status +BuildEnsembleGraph( + const inference::ModelConfig& config, + std::unordered_map& keyed_ensemble_graph) +{ + keyed_ensemble_graph.clear(); + size_t step_idx = 0; + for (const auto& element : config.ensemble_scheduling().step()) { + if (element.model_name().empty()) { + return Status( + Status::Code::INVALID_ARG, + "must specify 'model_name' in step " + std::to_string(step_idx) + + " of ensemble '" + config.name() + "'"); + } + if (element.input_map().size() == 0) { + return Status( + Status::Code::INVALID_ARG, + "must specify 'input_map' in step " + std::to_string(step_idx) + + " of ensemble '" + config.name() + "'"); + } + if (element.output_map().size() == 0) { + return Status( + Status::Code::INVALID_ARG, + "must specify 'output_map' in step " + std::to_string(step_idx) + + " of ensemble '" + config.name() + "'"); + } + + // Link ensemble tensors + std::vector tensor_as_output; + for (const auto& output_map : element.output_map()) { + auto it = keyed_ensemble_graph.find(output_map.second); + if (it != keyed_ensemble_graph.end()) { + if (it->second.isOutput) { + return Status( + Status::Code::INVALID_ARG, + "ensemble tensor '" + it->first + + "' can appear in an output map only once for ensemble '" + + config.name() + "' step " + std::to_string(step_idx)); + } else { + it->second.isOutput = true; + } + } else { + it = keyed_ensemble_graph + .emplace( + std::make_pair(output_map.second, EnsembleTensor(true))) + .first; + } + tensor_as_output.push_back(&(it->second)); + } + + std::set model_inputs; + for (const auto& input_map : element.input_map()) { + if (model_inputs.find(input_map.first) != model_inputs.end()) { + return Status( + Status::Code::INVALID_ARG, + "input '" + input_map.first + "' in model '" + + element.model_name() + + "' is mapped to multiple ensemble tensors for ensemble '" + + config.name() + "' step " + std::to_string(step_idx)); + } else { + model_inputs.emplace(input_map.first); + } + auto it = keyed_ensemble_graph.find(input_map.second); + if (it == keyed_ensemble_graph.end()) { + it = keyed_ensemble_graph + .emplace( + std::make_pair(input_map.second, EnsembleTensor(false))) + .first; + } + for (auto output : tensor_as_output) { + output->prev_nodes.push_back(&(it->second)); + it->second.next_nodes.push_back(output); + } + } + + step_idx++; + } + + return Status::Success; +} + +Status +ValidateEnsembleSchedulingConfig(const inference::ModelConfig& config) +{ + if (config.platform() != kEnsemblePlatform) { + return Status( + Status::Code::INVALID_ARG, + "ensemble scheduling cannot be set for model '" + config.name() + + "' whose platform is not " + kEnsemblePlatform); + } + if (config.instance_group().size() != 0) { + return Status( + Status::Code::INVALID_ARG, + "instance group should not be specified for ensemble '" + + config.name() + "'"); + } + if (config.has_optimization()) { + return Status( + Status::Code::INVALID_ARG, + "optimization should not be specified for ensemble '" + config.name() + + "'"); + } + if (config.model_warmup_size() != 0) { + return Status( + Status::Code::INVALID_ARG, + "model_warmup can not be specified for ensemble '" + config.name() + + "'"); + } + + // Make sure step is not empty and all fields are set + if (config.ensemble_scheduling().step_size() == 0) { + return Status( + Status::Code::INVALID_ARG, + "must specify 'step' for ensemble '" + config.name() + "'"); + } + + std::unordered_map tensors; + + RETURN_IF_ERROR(BuildEnsembleGraph(config, tensors)); + + // check data flow + std::deque ready_queue; + for (const auto& 
input : config.input()) { + auto it = tensors.find(input.name()); + if (it == tensors.end()) { + return Status( + Status::Code::INVALID_ARG, "ensemble input '" + input.name() + + "' for ensemble " + config.name() + + "' is not used"); + } + it->second.ready = true; + ready_queue.push_back(&(it->second)); + } + while (!ready_queue.empty()) { + auto& ready_node = ready_queue.front(); + for (auto& next_node : ready_node->next_nodes) { + if (next_node->ready) { + continue; + } + bool next_node_ready = true; + for (auto& prev_node : next_node->prev_nodes) { + if (!prev_node->ready) { + next_node_ready = false; + break; + } + } + next_node->ready = next_node_ready; + if (next_node_ready) { + ready_queue.push_back(next_node); + } + } + ready_queue.pop_front(); + } + std::set outputs; + for (const auto& output : config.output()) { + auto it = tensors.find(output.name()); + if (it == tensors.end()) { + return Status( + Status::Code::INVALID_ARG, "ensemble output '" + output.name() + + "' for ensemble " + config.name() + + "' is not used"); + } + if (!it->second.ready) { + return Status( + Status::Code::INVALID_ARG, "output '" + output.name() + + "' for ensemble '" + config.name() + + "' is not written"); + } else { + outputs.insert(it->first); + } + } + // Check redundant ensemble tensors + for (const auto& tensor : tensors) { + // skip ensemble outputs as they have been checked and can have no + // next nodes + if (outputs.find(tensor.first) != outputs.end()) { + continue; + } + if (!tensor.second.ready || (tensor.second.next_nodes.size() == 0)) { + return Status( + Status::Code::INVALID_ARG, "ensemble tensor '" + tensor.first + + "' is unused in ensemble '" + + config.name() + "'"); + } + } + return Status::Success; +} + +#endif // TRITON_ENABLE_ENSEMBLE + +template +Status +ValidateIOShape( + const ModelIO& io, int32_t max_batch_size, + const std::string& message_prefix = "") +{ + if (io.name().empty()) { + return Status( + Status::Code::INVALID_ARG, message_prefix + "must specify 'name'"); + } + + if (io.data_type() == inference::DataType::TYPE_INVALID) { + return Status( + Status::Code::INVALID_ARG, "model output must specify 'data_type'"); + } + + if (io.dims_size() == 0) { + return Status( + Status::Code::INVALID_ARG, message_prefix + "must specify 'dims'"); + } + + // If the configuration is non-batching, then no input or output + // reshape can be empty as that would mean that input or output was + // always empty (no data). + if (io.has_reshape() && (io.reshape().shape_size() == 0) && + (max_batch_size == 0)) { + return Status( + Status::Code::INVALID_ARG, + message_prefix + + "cannot have empty reshape for non-batching model as scalar " + "tensors are not supported"); + } + + for (auto dim : io.dims()) { + // Dimension cannot be 0. + if ((dim < 1) && (dim != triton::common::WILDCARD_DIM)) { + return Status( + Status::Code::INVALID_ARG, + message_prefix + "dimension must be integer >= 1, or " + + std::to_string(triton::common::WILDCARD_DIM) + + " to indicate a variable-size dimension"); + } + } + + if (io.has_reshape()) { + // Zeros are not allowed in reshape. 
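+    // For example (hypothetical config), dims: [ 3, 224, 224 ] with
+    // reshape { shape: [ 1, 3, 224, 224 ] } is accepted because every
+    // reshape dimension is >= 1 (or the wildcard -1) and both shapes
+    // describe the same element count, while reshape { shape: [ 0, 3 ] }
+    // is rejected below.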
+ for (auto dim : io.reshape().shape()) { + if ((dim < 1) && (dim != triton::common::WILDCARD_DIM)) { + return Status( + Status::Code::INVALID_ARG, + message_prefix + "reshape dimensions must be integer >= 1, or " + + std::to_string(triton::common::WILDCARD_DIM) + + " to indicate a variable-size dimension"); + } + } + + const int64_t dims_size = triton::common::GetElementCount(io.dims()); + const int64_t reshape_size = + triton::common::GetElementCount(io.reshape().shape()); + + // dims and reshape must both have same element count + // or both have variable-size dimension. + // Special case for empty reshape... expect dims to have element + // count of 1. + if ((dims_size != reshape_size) && + ((reshape_size != 0) || (dims_size != 1))) { + return Status( + Status::Code::INVALID_ARG, + message_prefix + "has different size for dims and reshape"); + } + + // shape contains variable-size dimension, in this case we compare if + // each pair of the trunks separated by variable-size dimension has + // the same element count. For instance, from [2, 4, -1, 6] to [8, -1, 1, 6] + // is valid reshape as 2 * 4 = 8 and 6 = 1 * 6. + if (dims_size == -1) { + std::vector dim_element_cnts; + std::vector reshape_element_cnts; + int64_t current_cnt = 1; + for (const auto& dim : io.dims()) { + if (dim != -1) { + current_cnt *= dim; + } else { + dim_element_cnts.push_back(current_cnt); + current_cnt = 1; + } + } + dim_element_cnts.push_back(current_cnt); + + current_cnt = 1; + for (const auto& dim : io.reshape().shape()) { + if (dim != -1) { + current_cnt *= dim; + } else { + reshape_element_cnts.push_back(current_cnt); + current_cnt = 1; + } + } + reshape_element_cnts.push_back(current_cnt); + + if (dim_element_cnts.size() != reshape_element_cnts.size()) { + return Status( + Status::Code::INVALID_ARG, + message_prefix + + "has different number of variable-size dimensions for dims " + "and reshape"); + } + for (size_t idx = 0; idx < dim_element_cnts.size(); idx++) { + if (dim_element_cnts[idx] != reshape_element_cnts[idx]) { + return Status( + Status::Code::INVALID_ARG, + message_prefix + "has different size for dims and reshape"); + } + } + } + } + + return Status::Success; +} + +} // namespace + +Status +GetModelVersionFromPath(const std::string& path, int64_t* version) +{ + auto version_dir = BaseName(path); + + // Determine the version from the last segment of 'path' + try { + *version = std::atoll(version_dir.c_str()); + } + catch (...) { + return Status( + Status::Code::INTERNAL, + "unable to determine model version from " + path); + } + + return Status::Success; +} + +Status +GetBooleanSequenceControlProperties( + const inference::ModelSequenceBatching& batcher, + const std::string& model_name, + const inference::ModelSequenceBatching::Control::Kind control_kind, + const bool required, std::string* tensor_name, + inference::DataType* tensor_datatype, float* fp32_false_value, + float* fp32_true_value, int32_t* int32_false_value, + int32_t* int32_true_value, bool* bool_false_value, bool* bool_true_value) +{ + // Make sure same tensor is not configured for multiple controls + std::set seen_tensors; + + // Make sure the control kind is not mentioned multiple times. 
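+  // A (hypothetical) well-formed boolean control looks like:
+  //   sequence_batching {
+  //     control_input {
+  //       name: "START"
+  //       control { kind: CONTROL_SEQUENCE_START int32_false_true: [ 0, 1 ] }
+  //     }
+  //   }
+  // i.e. exactly one of int32_false_true, fp32_false_true or bool_false_true
+  // is given, with exactly two entries (false value, true value).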
+ bool seen_control = false; + + for (const auto& control_input : batcher.control_input()) { + if (control_input.name().empty()) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching control tensor must have a name for " + + model_name); + } + + if (seen_tensors.find(control_input.name()) != seen_tensors.end()) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching control tensor '" + control_input.name() + + "' is specified for multiple control kinds for " + model_name); + } + + seen_tensors.insert(control_input.name()); + + for (const auto& c : control_input.control()) { + if (c.kind() == control_kind) { + if (seen_control) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching specifies multiple " + + inference::ModelSequenceBatching_Control_Kind_Name( + control_kind) + + " tensors for " + model_name); + } + + *tensor_name = control_input.name(); + seen_control = true; + + // Make sure only one of int, float, or bool type is specified. + if (!((c.int32_false_true_size() != 0) || + (c.fp32_false_true_size() != 0) || + (c.bool_false_true_size() != 0))) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching must specify either 'int32_false_true', " + "'fp32_false_true' or 'bool_false_true' for " + + inference::ModelSequenceBatching_Control_Kind_Name( + control_kind) + + " for " + model_name); + } else if ( + ((c.int32_false_true_size() != 0) && + (c.fp32_false_true_size() != 0)) || + ((c.int32_false_true_size() != 0) && + (c.bool_false_true_size() != 0)) || + ((c.fp32_false_true_size() != 0) && + (c.bool_false_true_size() != 0))) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching specifies more than one from " + "'int32_false_true', 'fp32_false_true' and 'bool_false_true' " + "for " + + inference::ModelSequenceBatching_Control_Kind_Name( + control_kind) + + " for " + model_name); + } + + if (c.int32_false_true_size() > 0) { + if (c.int32_false_true_size() != 2) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching control 'int32_false_true' must have " + "exactly 2 entries for " + + inference::ModelSequenceBatching_Control_Kind_Name( + control_kind) + + " for " + model_name); + } + + if (tensor_datatype != nullptr) { + *tensor_datatype = inference::DataType::TYPE_INT32; + } + if (int32_false_value != nullptr) { + *int32_false_value = c.int32_false_true(0); + } + if (int32_true_value != nullptr) { + *int32_true_value = c.int32_false_true(1); + } + } else if (c.fp32_false_true_size() > 0) { + if (c.fp32_false_true_size() != 2) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching control 'fp32_false_true' must have exactly " + "2 entries for " + + inference::ModelSequenceBatching_Control_Kind_Name( + control_kind) + + " for " + model_name); + } + + if (tensor_datatype != nullptr) { + *tensor_datatype = inference::DataType::TYPE_FP32; + } + if (fp32_false_value != nullptr) { + *fp32_false_value = c.fp32_false_true(0); + } + if (fp32_true_value != nullptr) { + *fp32_true_value = c.fp32_false_true(1); + } + } else { + if (c.bool_false_true_size() != 2) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching control 'bool_false_true' must have exactly " + "2 entries for " + + inference::ModelSequenceBatching_Control_Kind_Name( + control_kind) + + " for " + model_name); + } + + if (tensor_datatype != nullptr) { + *tensor_datatype = inference::DataType::TYPE_BOOL; + } + if (bool_false_value != nullptr) { + *bool_false_value = c.bool_false_true(0); + } + if (bool_true_value != 
nullptr) { + *bool_true_value = c.bool_false_true(1); + } + } + } + } + } + + if (!seen_control) { + if (required) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching control tensor must specify a " + + inference::ModelSequenceBatching_Control_Kind_Name(control_kind) + + " value for " + model_name); + } + + tensor_name->clear(); + } + + return Status::Success; +} + +Status +GetTypedSequenceControlProperties( + const inference::ModelSequenceBatching& batcher, + const std::string& model_name, + const inference::ModelSequenceBatching::Control::Kind control_kind, + const bool required, std::string* tensor_name, + inference::DataType* tensor_datatype) +{ + // Make sure same tensor is not configured for multiple controls + std::set seen_tensors; + + // Make sure the control kind is not mentioned multiple times. + bool seen_control = false; + + for (const auto& control_input : batcher.control_input()) { + if (control_input.name().empty()) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching control tensor must have a name for " + + model_name); + } + + if (seen_tensors.find(control_input.name()) != seen_tensors.end()) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching control tensor '" + control_input.name() + + "' is specified for multiple control kinds for " + model_name); + } + + seen_tensors.insert(control_input.name()); + + for (const auto& c : control_input.control()) { + if (c.kind() == control_kind) { + if (seen_control) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching specifies multiple " + + inference::ModelSequenceBatching_Control_Kind_Name( + control_kind) + + " tensors for " + model_name); + } + + *tensor_name = control_input.name(); + if (tensor_datatype != nullptr) { + *tensor_datatype = c.data_type(); + } + + seen_control = true; + + if ((c.int32_false_true_size() > 0) || (c.fp32_false_true_size() > 0) || + (c.bool_false_true_size() > 0)) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching must not specify either 'int32_false_true', " + "'fp32_false_true' or 'bool_false_true' for " + + inference::ModelSequenceBatching_Control_Kind_Name( + control_kind) + + " for " + model_name); + } + } + } + } + + if (!seen_control) { + if (required) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching control tensor must specify a " + + inference::ModelSequenceBatching_Control_Kind_Name(control_kind) + + " value for " + model_name); + } + + tensor_name->clear(); + } + + return Status::Success; +} + +Status +GetNormalizedModelConfig( + const std::string& model_name, const std::string& path, + const double min_compute_capability, inference::ModelConfig* config) +{ + // Server-side autofill only sets certain backend fields for the models that + // belong to limited backends for backwards-compatibility. See TensorRT + // backend, ONNX Runtime backend, OpenVINO backend, TensorFLow backend, and + // PyTorch backend. + // Extracting detailed information is delegated to the backend implementation + // to auto-complete. + RETURN_IF_ERROR( + AutoCompleteBackendFields(model_name, std::string(path), config)); + LOG_VERBOSE(1) << "Server side auto-completed config: " + << config->DebugString(); + + RETURN_IF_ERROR(NormalizeModelConfig(min_compute_capability, config)); + + return Status::Success; +} + +Status +NormalizeModelConfig( + const double min_compute_capability, inference::ModelConfig* config) +{ + // If version_policy is not specified, default to Latest 1 version. 
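+  // This is equivalent to the model author writing
+  //   version_policy { latest { num_versions: 1 } }
+  // in the model's config.pbtxt.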
+ if (!config->has_version_policy()) { + inference::ModelVersionPolicy::Latest latest; + latest.set_num_versions(1); + config->mutable_version_policy()->mutable_latest()->CopyFrom(latest); + } + + // If dynamic batching is specified... + if (config->has_dynamic_batching()) { + // If preferred batch size is not specified set it to + // max-batch-size. + if (config->dynamic_batching().preferred_batch_size().size() == 0) { + auto mutable_preferred_batch_size = + config->mutable_dynamic_batching()->mutable_preferred_batch_size(); + if (config->max_batch_size() > 0) { + mutable_preferred_batch_size->Add(config->max_batch_size()); + } + } + } + + // If sequence batching is specified... + if (config->has_sequence_batching()) { + // Set default idle is not specified. + if (config->sequence_batching().max_sequence_idle_microseconds() == 0) { + config->mutable_sequence_batching()->set_max_sequence_idle_microseconds( + SEQUENCE_IDLE_DEFAULT_MICROSECONDS); + } + + if (config->sequence_batching().has_oldest()) { + // If preferred batch size is not specified set it to + // max-batch-size. + if (config->sequence_batching().oldest().preferred_batch_size().size() == + 0) { + auto mutable_preferred_batch_size = + config->mutable_sequence_batching() + ->mutable_oldest() + ->mutable_preferred_batch_size(); + if (config->max_batch_size() > 0) { + mutable_preferred_batch_size->Add(config->max_batch_size()); + } + } + } + } + + // If model ensembling is specified, don't attempt to normalize instance_group + // as it is not allowed in ensemble scheduling + if (!config->has_ensemble_scheduling()) { + auto optimization = config->mutable_optimization(); + if (!optimization->has_input_pinned_memory()) { + optimization->mutable_input_pinned_memory()->set_enable(true); + } + if (!optimization->has_output_pinned_memory()) { + optimization->mutable_output_pinned_memory()->set_enable(true); + } + } + + return Status::Success; +} + +Status +NormalizeInstanceGroup( + const double min_compute_capability, + const std::vector& preferred_groups, + inference::ModelConfig* config) +{ + // Instance group setting doesn't apply to ensemble + if (config->has_ensemble_scheduling()) { + return Status::Success; + } + + // Creates a set of supported GPU device ids + std::set supported_gpus; +#ifdef TRITON_ENABLE_GPU + // Get the total number of GPUs from the runtime library. + Status status = GetSupportedGPUs(&supported_gpus, min_compute_capability); + if (!status.IsOk()) { + return status; + } + +#endif // TRITON_ENABLE_GPU + + // Make sure there is at least one instance_group. + if (config->instance_group().empty()) { + inference::ModelInstanceGroup* group = config->add_instance_group(); + group->set_name(config->name()); + + for (const auto& pg : preferred_groups) { + group->set_kind(pg.kind()); + group->set_count(pg.count()); + // handle preferred GPU setting differently based on kind + if (pg.kind() == inference::ModelInstanceGroup::KIND_GPU) { + // Don't use preferred group with KIND_GPU if there is no GPU. 
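+        // The 'continue' below falls through to the next entry in
+        // 'preferred_groups' (e.g. a later KIND_CPU preferred group)
+        // instead of creating an unusable GPU instance group.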
+ if (supported_gpus.empty()) { + continue; + } + // If preferred group sets GPUs, limit deployment onto those that + // are also listed in supported gpus + if (!pg.gpus().empty()) { + for (const int32_t gid : pg.gpus()) { + if (supported_gpus.find(gid) != supported_gpus.end()) { + group->add_gpus(gid); + } + } + } + break; + } else if (pg.kind() == inference::ModelInstanceGroup::KIND_AUTO) { + // if AUTO, then set preferred GPU as is, to align with KIND_AUTO + // deduction specified below + for (const int32_t gid : pg.gpus()) { + group->add_gpus(gid); + } + break; + } + // Other kind should not set GPUs + break; + } + } + + // Assign default name, kind and count to each instance group that + // doesn't give those values explicitly. For KIND_GPU, set GPUs to + // all available if not specified explicitly. + size_t cnt = 0; + for (auto& group : *config->mutable_instance_group()) { + // Name + if (group.name().empty()) { + group.set_name(config->name() + "_" + std::to_string(cnt)); + } + cnt++; + + // For KIND_AUTO... if there are no GPUs or if any of the listed + // 'gpu's are not present, then use KIND_CPU. + if (group.kind() == inference::ModelInstanceGroup::KIND_AUTO) { + if (supported_gpus.empty()) { + group.set_kind(inference::ModelInstanceGroup::KIND_CPU); + } else { + for (const int32_t gid : group.gpus()) { + if (supported_gpus.find(gid) == supported_gpus.end()) { + group.set_kind(inference::ModelInstanceGroup::KIND_CPU); + break; + } + } + } + + if (group.kind() == inference::ModelInstanceGroup::KIND_AUTO) { + group.set_kind(inference::ModelInstanceGroup::KIND_GPU); + } + } + + // KIND is resolved at this point + for (const auto& pg : preferred_groups) { + if (group.kind() != pg.kind()) { + continue; + } + + // Limit the GPU setting within what is specified in the preferred group, + // if no available GPU then skip to next preferred group + if ((group.kind() == inference::ModelInstanceGroup::KIND_GPU) && + group.gpus().empty() && !pg.gpus().empty()) { + for (const int32_t gid : pg.gpus()) { + if (supported_gpus.find(gid) != supported_gpus.end()) { + group.add_gpus(gid); + } + } + if (group.gpus().empty()) { + continue; + } + } + if ((group.count() < 1) && (pg.count() > 0)) { + group.set_count(pg.count()); + } + } + + // Set Triton default if the fields are not set from preferred group + // Count + if (group.count() < 1) { + RETURN_IF_ERROR(SetDefaultInstanceCount(&group, config->backend())); + } + + // GPUs + if ((group.kind() == inference::ModelInstanceGroup::KIND_GPU) && + (group.gpus().size() == 0)) { + for (auto d : supported_gpus) { + group.add_gpus(d); + } + } + } + + return Status::Success; +} + +Status +LocalizePythonBackendExecutionEnvironmentPath( + const std::string& model_path, inference::ModelConfig* config, + std::shared_ptr* localized_model_dir) +{ + if (config->backend() == "python") { + if (config->parameters().contains("EXECUTION_ENV_PATH")) { + // Read EXECUTION_ENV_PATH + std::string exec_env_path = + config->parameters().at("EXECUTION_ENV_PATH").string_value(); + // Replace model directory variable with model_path + std::string model_dir_var = "$$TRITON_MODEL_DIRECTORY"; + if (exec_env_path.substr(0, model_dir_var.size()) == model_dir_var) { + exec_env_path.replace(0, model_dir_var.size(), model_path); + } + // Collapse any .. 
in the path + std::string abs_exec_env_path; + std::size_t prev_pos = exec_env_path.size(); + std::size_t pos = exec_env_path.find_last_of('/', prev_pos - 1); + int skip = 0; + while (pos != std::string::npos && prev_pos > 0) { + if (!skip) { + abs_exec_env_path = + exec_env_path.substr(pos, prev_pos - pos) + abs_exec_env_path; + } + skip = skip > 0 ? skip - 1 : skip; + if (pos >= 3 && exec_env_path.substr(pos - 3, 3) == "/..") { + skip += 2; + } + prev_pos = pos; + pos = exec_env_path.find_last_of('/', prev_pos - 1); + } + abs_exec_env_path = exec_env_path.substr(0, prev_pos) + abs_exec_env_path; + // Localize iff abs_exec_env_path is outside the model directory + std::string model_path_slash = + model_path.back() == '/' ? model_path : model_path + "/"; + if (abs_exec_env_path.substr(0, model_path_slash.size()) != + model_path_slash) { + // Localize the file + std::shared_ptr localized_exec_env_path; + RETURN_IF_ERROR( + LocalizePath(abs_exec_env_path, &localized_exec_env_path)); + // Persist the localized temporary path + (*localized_model_dir) + ->other_localized_path.push_back(localized_exec_env_path); + // Rewrite EXECUTION_ENV_PATH + config->mutable_parameters() + ->at("EXECUTION_ENV_PATH") + .set_string_value(localized_exec_env_path->Path()); + } + } + } + return Status::Success; +} + +Status +SetDefaultInstanceCount( + inference::ModelInstanceGroup* group, const std::string& backend) +{ + group->set_count(1); + + // Backends opt into the default_cpu_instance_count since + // some backends (pytorch, OpenVINO) don't perform well/have high overhead + // when using multiple instances. + const int default_cpu_instance_count = 2; + bool use_default_cpu_instance_count = + (backend == kTensorFlowBackend) || (backend == kOnnxRuntimeBackend); + if (group->kind() == inference::ModelInstanceGroup::KIND_CPU && + use_default_cpu_instance_count) { + group->set_count(default_cpu_instance_count); + } + + return Status::Success; +} + +Status +AutoCompleteBackendFields( + const std::string& model_name, const std::string& model_path, + inference::ModelConfig* config) +{ + std::set version_dirs; + RETURN_IF_ERROR(GetDirectorySubdirs(model_path, &version_dirs)); + + // There must be at least one version directory that we can inspect to + // attempt to determine the platform. If not, we skip autofill with file name. + // For now we allow multiple versions and only inspect the first verison + // directory to ensure it is valid. We can add more aggressive checks later. + const bool has_version = (version_dirs.size() != 0); + const auto version_path = + has_version ? JoinPath({model_path, *(version_dirs.begin())}) : ""; + std::set version_dir_content; + if (has_version) { + RETURN_IF_ERROR(GetDirectoryContents(version_path, &version_dir_content)); + } + + // If the model name is not given in the configuration, set if based + // on the model path. + if (config->name().empty()) { + config->set_name(model_name); + } + + // Trying to fill the 'backend', 'default_model_filename' field. 
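+  // The strategy below: honor whatever the config already states (platform,
+  // backend or default_model_filename) and only then inspect the first
+  // version directory, e.g. an entry matching kTensorFlowSavedModelFilename
+  // (conventionally "model.savedmodel") suggests the TensorFlow SavedModel
+  // platform, while an entry matching kTensorRTPlanFilename suggests
+  // TensorRT.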
+ + // TensorFlow + // For TF backend, the platform is required + if (config->platform().empty()) { + // Check 'backend', 'default_model_filename', and the actual directory + // to determine the platform + if (config->backend().empty() || + (config->backend() == kTensorFlowBackend)) { + if (config->default_model_filename() == kTensorFlowSavedModelFilename) { + config->set_platform(kTensorFlowSavedModelPlatform); + } else if ( + config->default_model_filename() == kTensorFlowGraphDefFilename) { + config->set_platform(kTensorFlowGraphDefPlatform); + } else if (config->default_model_filename().empty() && has_version) { + bool is_dir = false; + if (version_dir_content.find(kTensorFlowSavedModelFilename) != + version_dir_content.end()) { + RETURN_IF_ERROR(IsDirectory( + JoinPath({version_path, kTensorFlowSavedModelFilename}), + &is_dir)); + if (is_dir) { + config->set_platform(kTensorFlowSavedModelPlatform); + } + } + if (version_dir_content.find(kTensorFlowGraphDefFilename) != + version_dir_content.end()) { + RETURN_IF_ERROR(IsDirectory( + JoinPath({version_path, kTensorFlowGraphDefFilename}), &is_dir)); + if (!is_dir) { + config->set_platform(kTensorFlowGraphDefPlatform); + } + } + } + } + } + + // Fill 'backend' and 'default_model_filename' if missing + if ((config->platform() == kTensorFlowSavedModelPlatform) || + (config->platform() == kTensorFlowGraphDefPlatform)) { + if (config->backend().empty()) { + config->set_backend(kTensorFlowBackend); + } + if (config->default_model_filename().empty()) { + if (config->platform() == kTensorFlowSavedModelPlatform) { + config->set_default_model_filename(kTensorFlowSavedModelFilename); + } else { + config->set_default_model_filename(kTensorFlowGraphDefFilename); + } + } + return Status::Success; + } + + // TensorRT + if (config->backend().empty()) { + if ((config->platform() == kTensorRTPlanPlatform) || + (config->default_model_filename() == kTensorRTPlanFilename)) { + config->set_backend(kTensorRTBackend); + } else if ( + config->platform().empty() && + config->default_model_filename().empty() && has_version) { + bool is_dir = false; + if (version_dir_content.find(kTensorRTPlanFilename) != + version_dir_content.end()) { + RETURN_IF_ERROR(IsDirectory( + JoinPath({version_path, kTensorRTPlanFilename}), &is_dir)); + if (!is_dir) { + config->set_backend(kTensorRTBackend); + } + } + } + } + if (config->backend() == kTensorRTBackend) { + if (config->platform().empty()) { + config->set_platform(kTensorRTPlanPlatform); + } + if (config->default_model_filename().empty()) { + config->set_default_model_filename(kTensorRTPlanFilename); + } + return Status::Success; + } + + // ONNXRuntime + if (config->backend().empty()) { + if ((config->platform() == kOnnxRuntimeOnnxPlatform) || + (config->default_model_filename() == kOnnxRuntimeOnnxFilename)) { + config->set_backend(kOnnxRuntimeBackend); + } else if ( + config->platform().empty() && + config->default_model_filename().empty() && has_version) { + if (version_dir_content.find(kOnnxRuntimeOnnxFilename) != + version_dir_content.end()) { + // ONNX model can be a file or a directory in the case of large model + config->set_backend(kOnnxRuntimeBackend); + } + } + } + if (config->backend() == kOnnxRuntimeBackend) { + if (config->platform().empty()) { + config->set_platform(kOnnxRuntimeOnnxPlatform); + } + if (config->default_model_filename().empty()) { + config->set_default_model_filename(kOnnxRuntimeOnnxFilename); + } + return Status::Success; + } + + // OpenVINO + if (config->backend().empty()) { + if 
(config->default_model_filename() == kOpenVINORuntimeOpenVINOFilename) { + config->set_backend(kOpenVINORuntimeBackend); + } else if ( + config->platform().empty() && + config->default_model_filename().empty() && has_version) { + if (version_dir_content.find(kOpenVINORuntimeOpenVINOFilename) != + version_dir_content.end()) { + config->set_backend(kOpenVINORuntimeBackend); + } + } + } + if (config->backend() == kOpenVINORuntimeBackend) { + if (config->default_model_filename().empty()) { + config->set_default_model_filename(kOpenVINORuntimeOpenVINOFilename); + } + return Status::Success; + } + + // PyTorch (TorchScript, LibTorch) + if (config->backend().empty()) { + if ((config->platform() == kPyTorchLibTorchPlatform) || + (config->default_model_filename() == kPyTorchLibTorchFilename)) { + config->set_backend(kPyTorchBackend); + } else if ( + config->platform().empty() && + config->default_model_filename().empty() && has_version) { + bool is_dir = false; + if (version_dir_content.find(kPyTorchLibTorchFilename) != + version_dir_content.end()) { + RETURN_IF_ERROR(IsDirectory( + JoinPath({version_path, kPyTorchLibTorchFilename}), &is_dir)); + if (!is_dir) { + config->set_backend(kPyTorchBackend); + } + } + } + } + if (config->backend() == kPyTorchBackend) { + if (config->platform().empty()) { + config->set_platform(kPyTorchLibTorchPlatform); + } + if (config->default_model_filename().empty()) { + config->set_default_model_filename(kPyTorchLibTorchFilename); + } + return Status::Success; + } + + // Python + if (config->backend().empty()) { + if (config->default_model_filename() == kPythonFilename) { + config->set_backend(kPythonBackend); + } else if ( + config->platform().empty() && + config->default_model_filename().empty() && has_version) { + if (version_dir_content.find(kPythonFilename) != + version_dir_content.end()) { + config->set_backend(kPythonBackend); + } + } + } + if (config->backend() == kPythonBackend) { + if (config->default_model_filename().empty()) { + config->set_default_model_filename(kPythonFilename); + } + return Status::Success; + } + + // Custom Backend + // For now, only do the narrowest case, where no info is given in the config. + if (config->backend().empty() && config->platform().empty() && + config->default_model_filename().empty()) { + LOG_VERBOSE(1) << "Could not infer supported backend, so attempting " + "autofill of custom backend."; + // Since we lazily load the backends, we let the model tell us what backend + // to load. We must assume that if the model name conforms to the required + // shape, we parse the backend name out of the model file name. i.e. + // model.identity will set the backend to "identity". + const std::string delimiter = "."; + size_t pos = model_name.find(delimiter, 0); + if (pos == std::string::npos) { + return Status( + triton::common::Error::Code::INVALID_ARG, + ("Invalid model name: Could not determine backend for model '" + + model_name + + "' with no backend in model configuration. 
Expected model name of " + "the form 'model.'.")); + } + const std::string backend_name = + model_name.substr(pos + 1, std::string::npos); + config->set_backend(backend_name); + config->set_default_model_filename( + (std::string("model.") + backend_name).c_str()); + return Status::Success; + } + + return Status::Success; +} + +Status +ValidateModelIOConfig(const inference::ModelConfig& config) +{ + Status status; + for (const auto& io : config.input()) { + status = ValidateModelInput(io, config.max_batch_size(), config.platform()); + if (!status.IsOk()) { + return Status( + status.StatusCode(), status.Message() + " for " + config.name()); + } + } + for (const auto& io : config.output()) { + status = + ValidateModelOutput(io, config.max_batch_size(), config.platform()); + if (!status.IsOk()) { + return Status( + status.StatusCode(), status.Message() + " for " + config.name()); + } + } + status = ValidateBatchIO(config); + if (!status.IsOk()) { + return Status( + status.StatusCode(), status.Message() + " for " + config.name()); + } + return Status::Success; +} + +Status +ValidateBatchIO(const inference::ModelConfig& config) +{ + std::set input_names; + std::set output_names; + for (const auto& io : config.input()) { + input_names.emplace(io.name()); + } + for (const auto& io : config.output()) { + output_names.emplace(io.name()); + } + for (const auto& batch_io : config.batch_input()) { + switch (batch_io.kind()) { + case inference::BatchInput::BATCH_ELEMENT_COUNT: + case inference::BatchInput::BATCH_ACCUMULATED_ELEMENT_COUNT: + case inference::BatchInput::BATCH_ACCUMULATED_ELEMENT_COUNT_WITH_ZERO: + case inference::BatchInput::BATCH_MAX_ELEMENT_COUNT_AS_SHAPE: + case inference::BatchInput::BATCH_ITEM_SHAPE: + case inference::BatchInput::BATCH_ITEM_SHAPE_FLATTEN: { + if (batch_io.source_input_size() != 1) { + return Status( + Status::Code::INVALID_ARG, + "batch input kind '" + + inference::BatchInput::Kind_Name(batch_io.kind()) + + "' expects 1 source input, got " + + std::to_string(batch_io.source_input_size())); + } + break; + } + default: + return Status( + Status::Code::INVALID_ARG, + "unknown batch input kind '" + + inference::BatchInput::Kind_Name(batch_io.kind()) + "'"); + } + if ((batch_io.data_type() != inference::DataType::TYPE_INT32) && + (batch_io.data_type() != inference::DataType::TYPE_FP32)) { + return Status( + Status::Code::INVALID_ARG, + "batch input data type must be TYPE_INT32 or TYPE_FP32"); + } + for (const auto& source_name : batch_io.source_input()) { + if (input_names.find(source_name) == input_names.end()) { + return Status( + Status::Code::INVALID_ARG, + "unknown source input name '" + source_name + "'"); + } + } + } + + for (const auto& batch_io : config.batch_output()) { + switch (batch_io.kind()) { + case inference::BatchOutput::BATCH_SCATTER_WITH_INPUT_SHAPE: { + if (batch_io.source_input_size() != 1) { + return Status( + Status::Code::INVALID_ARG, + "batch output kind '" + + inference::BatchOutput::Kind_Name(batch_io.kind()) + + "' expects 1 source input, got " + + std::to_string(batch_io.source_input_size())); + } + break; + } + default: + return Status( + Status::Code::INVALID_ARG, + "unknown batch output kind '" + + inference::BatchOutput::Kind_Name(batch_io.kind()) + "'"); + } + for (const auto& source_name : batch_io.source_input()) { + if (input_names.find(source_name) == input_names.end()) { + return Status( + Status::Code::INVALID_ARG, + "unknown source input name '" + source_name + "'"); + } + } + std::set target_names; + for (const auto& target_name 
: batch_io.target_name()) { + if (output_names.find(target_name) == output_names.end()) { + return Status( + Status::Code::INVALID_ARG, + "unknown target output name '" + target_name + "'"); + } + if (target_names.emplace(target_name).second == false) { + return Status( + Status::Code::INVALID_ARG, "target output name '" + target_name + + "' can only be specified once"); + } + } + } + return Status::Success; +} + +Status +ValidateModelConfig( + const inference::ModelConfig& config, const double min_compute_capability) +{ + if (config.name().empty()) { + return Status( + Status::Code::INVALID_ARG, "model configuration must specify 'name'"); + } + + if (config.backend().empty()) { + // Expect backend is not empty unless it is ensemble platform. +#ifdef TRITON_ENABLE_ENSEMBLE + if (config.platform() != kEnsemblePlatform) +#endif // TRITON_ENABLE_ENSEMBLE + return Status( + Status::Code::INVALID_ARG, "unexpected platform type '" + + config.platform() + "' for " + + config.name()); + } +#ifdef TRITON_ENABLE_ENSEMBLE + else if (config.platform() == kEnsemblePlatform) { + return Status( + Status::Code::INVALID_ARG, + "Ensemble model '" + config.name() + "' must have platform type '" + + config.platform() + "' and empty backend type"); + } +#endif // TRITON_ENABLE_ENSEMBLE + + if (config.platform().empty() && config.backend().empty()) { + return Status( + Status::Code::INVALID_ARG, + "must specify 'platform' or 'backend' for '" + config.name() + "'"); + } + + // Ensure both platform and backend are referring to known backend, + // or both referring to unknown backend for user-provided backend. + if (GetBackendTypeFromPlatform(config.platform()) != + GetBackendType(config.backend())) { + return Status( + Status::Code::INVALID_ARG, + "unexpected 'platform' and 'backend' pair, got:" + config.platform() + + ", " + config.backend()); + } + + if (config.max_batch_size() < 0) { + return Status( + Status::Code::INVALID_ARG, + "'max_batch_size' must be non-negative value for " + config.name()); + } + + if (!config.has_version_policy()) { + return Status( + Status::Code::INVALID_ARG, + "must specify 'version policy' for " + config.name()); + } + + // If dynamic batching is specified make sure the preferred batch + // sizes are positive and don't exceed maximum batch size. 
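+  // For example (hypothetical config), with max_batch_size: 8 the setting
+  //   dynamic_batching { preferred_batch_size: [ 4, 8 ] }
+  // is accepted, while preferred_batch_size: [ 16 ] or [ 0 ] is rejected
+  // below.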
+ if (config.has_dynamic_batching()) { + for (const auto size : config.dynamic_batching().preferred_batch_size()) { + if (size <= 0) { + return Status( + Status::Code::INVALID_ARG, + "dynamic batching preferred size must be positive for " + + config.name()); + } + if (size > config.max_batch_size()) { + return Status( + Status::Code::INVALID_ARG, + "dynamic batching preferred size must be <= max batch size for " + + config.name()); + } + } + + // Priority queue is specified + const auto priority_levels = config.dynamic_batching().priority_levels(); + if (priority_levels != 0) { + if ((config.dynamic_batching().default_priority_level() == 0) || + (config.dynamic_batching().default_priority_level() > + priority_levels)) { + return Status( + Status::Code::INVALID_ARG, + "default priority level must be in range [1, " + + std::to_string(priority_levels) + "] for " + config.name()); + } + for (const auto& queue_policy : + config.dynamic_batching().priority_queue_policy()) { + if ((queue_policy.first == 0) || + (queue_policy.first > priority_levels)) { + return Status( + Status::Code::INVALID_ARG, + "priority queue policy must have priority level in range [1, " + + std::to_string(priority_levels) + "] for " + config.name()); + } + } + } + + // preserve ordering option will conflict with priorities and delay policy + if (config.dynamic_batching().preserve_ordering()) { + if (priority_levels > 1) { + return Status( + Status::Code::INVALID_ARG, + "Only one priority level is allowed when 'preserve_ordering' is " + "true for " + + config.name()); + } + const auto& default_policy = + config.dynamic_batching().default_queue_policy(); + if ((default_policy.default_timeout_microseconds() != 0) && + (default_policy.timeout_action() == + inference::ModelQueuePolicy::DELAY)) { + return Status( + Status::Code::INVALID_ARG, + "Queue policy can not have DELAY as timeout action when " + "'preserve_ordering' is true for " + + config.name()); + } + // Also need to check policy in 'priority_queue_policy' + // for single priority case + for (const auto& policy : + config.dynamic_batching().priority_queue_policy()) { + if ((policy.second.default_timeout_microseconds() != 0) && + (policy.second.timeout_action() == + inference::ModelQueuePolicy::DELAY)) { + return Status( + Status::Code::INVALID_ARG, + "Queue policy can not have DELAY as timeout action when " + "'preserve_ordering' is true for " + + config.name()); + } + } + } + } + + // If sequence batching is specified make sure the control is + // specified correctly. + if (config.has_sequence_batching()) { + const auto& batcher = config.sequence_batching(); + + // Check boolean controls... + std::string tensor_name; + RETURN_IF_ERROR(GetBooleanSequenceControlProperties( + batcher, config.name(), + inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_START, + false /* required */, &tensor_name, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr)); + RETURN_IF_ERROR(GetBooleanSequenceControlProperties( + batcher, config.name(), + inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_END, + false /* required */, &tensor_name, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr)); + RETURN_IF_ERROR(GetBooleanSequenceControlProperties( + batcher, config.name(), + inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_READY, + false /* required */, &tensor_name, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr)); + + // Check CORRID control and make sure it is one of the allowed types. 
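+  // A (hypothetical) typed control looks like:
+  //   control_input {
+  //     name: "CORRID"
+  //     control { kind: CONTROL_SEQUENCE_CORRID data_type: TYPE_UINT64 }
+  //   }
+  // where, unlike the boolean controls above, a data_type is given instead
+  // of false/true values.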
+ inference::DataType tensor_datatype; + RETURN_IF_ERROR(GetTypedSequenceControlProperties( + batcher, config.name(), + inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_CORRID, + false /* required */, &tensor_name, &tensor_datatype)); + if (!tensor_name.empty()) { + if ((tensor_datatype != inference::DataType::TYPE_UINT64) && + (tensor_datatype != inference::DataType::TYPE_INT64) && + (tensor_datatype != inference::DataType::TYPE_UINT32) && + (tensor_datatype != inference::DataType::TYPE_INT32) && + (tensor_datatype != inference::DataType::TYPE_STRING)) { + return Status( + Status::Code::INVALID_ARG, + "unexpected data type for control " + + inference::ModelSequenceBatching_Control_Kind_Name( + inference::ModelSequenceBatching::Control:: + CONTROL_SEQUENCE_CORRID) + + " for " + config.name() + + ". Allowed data types are TYPE_UINT64, TYPE_INT64, " + "TYPE_UINT32, " + "TYPE_INT32 and TYPE_STRING"); + } + } + + // If oldest-first strategy is enabled make sure the preferred + // batch sizes are positive and don't exceed maximum batch size. + if (config.sequence_batching().has_oldest()) { + for (const auto size : + config.sequence_batching().oldest().preferred_batch_size()) { + if (size <= 0) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching preferred batch size must be positive for " + + config.name()); + } + if (size > config.max_batch_size()) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching preferred batch size must be <= max batch " + "size for " + + config.name()); + } + } + } + + // If direct strategy is enabled make sure the minimum slot utilization is + // in range (0.0, 1.0] + if (config.sequence_batching().has_direct()) { + if ((config.sequence_batching().direct().minimum_slot_utilization() < + 0.0) || + (config.sequence_batching().direct().minimum_slot_utilization() > + 1.0)) { + return Status( + Status::Code::INVALID_ARG, + "sequence batching minimum slot utilization must be in range " + "(0.0, 1.0] for " + + config.name()); + } + } + } + + // If ensemble scheduling is specified, validate it. Otherwise, + // must validate platform and instance_group + if (config.has_ensemble_scheduling()) { +#ifdef TRITON_ENABLE_ENSEMBLE + RETURN_IF_ERROR(ValidateEnsembleSchedulingConfig(config)); +#else + return Status( + Status::Code::INVALID_ARG, "ensemble scheduling not supported"); +#endif // TRITON_ENABLE_ENSEMBLE + } +#ifdef TRITON_ENABLE_ENSEMBLE + else if (config.platform() == kEnsemblePlatform) { + return Status( + Status::Code::INVALID_ARG, + "ensemble scheduling must be set for ensemble " + config.name() + + " whose platform is " + kEnsemblePlatform); + } +#endif // TRITON_ENABLE_ENSEMBLE + + // FIXME: DLIS-3916 - Response Cache does not yet support decoupled models + if (config.model_transaction_policy().decoupled() && + config.response_cache().enable()) { + return Status( + Status::Code::INVALID_ARG, + "Response Cache does not currently support model " + config.name() + + " with 'decoupled' transaction policy. 
Please disable the response" + " cache."); + } + + return Status::Success; +} + +Status +ValidateInstanceGroup( + const inference::ModelConfig& config, const double min_compute_capability) +{ + // Instance group setting doesn't apply to ensemble + if (config.has_ensemble_scheduling()) { + return Status::Success; + } + + if (config.instance_group().size() == 0) { + return Status( + Status::Code::INVALID_ARG, + "must specify one or more 'instance group's for " + config.name()); + } + + // Make sure KIND_GPU instance group specifies at least one GPU and + // doesn't specify a non-existent GPU. Make sure non-KIND_GPU does + // not specify any GPUs. +#ifdef TRITON_ENABLE_GPU + std::set supported_gpus; + Status status = GetSupportedGPUs(&supported_gpus, min_compute_capability); + if (!status.IsOk()) { + return status; + } +#endif // TRITON_ENABLE_GPU + + for (const auto& group : config.instance_group()) { + if (group.kind() == inference::ModelInstanceGroup::KIND_MODEL) { + if (group.gpus().size() > 0) { + return Status( + Status::Code::INVALID_ARG, + "instance group " + group.name() + " of model " + config.name() + + " has kind KIND_MODEL but specifies one or more GPUs"); + } + } else if (group.kind() == inference::ModelInstanceGroup::KIND_GPU) { +#if !defined(TRITON_ENABLE_GPU) && !defined(TRITON_ENABLE_MALI_GPU) + return Status( + Status::Code::INVALID_ARG, + "instance group " + group.name() + " of model " + config.name() + + " has kind KIND_GPU but server does not support GPUs"); +#elif defined(TRITON_ENABLE_GPU) + if (group.gpus().size() == 0) { + if (supported_gpus.size() == 0) { + return Status( + Status::Code::INVALID_ARG, + "instance group " + group.name() + " of model " + config.name() + + " has kind KIND_GPU but no GPUs are available"); + } else { + return Status( + Status::Code::INVALID_ARG, + "instance group " + group.name() + " of model " + config.name() + + " has kind KIND_GPU but specifies no GPUs"); + } + } + + for (const int32_t gid : group.gpus()) { + if (supported_gpus.find(gid) == supported_gpus.end()) { + std::string supported_gpus_str; + for (const auto& cc : supported_gpus) { + if (!supported_gpus_str.empty()) { + supported_gpus_str += ", "; + } + supported_gpus_str += std::to_string(cc); + } + return Status( + Status::Code::INVALID_ARG, + "instance group " + group.name() + " of model " + config.name() + + " specifies invalid or unsupported gpu id " + + std::to_string(gid) + + ". GPUs with at least the minimum required CUDA compute " + "compatibility of " + + std::to_string(min_compute_capability) + + " are: " + supported_gpus_str); + } + } +#endif // ! TRITON_ENABLE_GPU && ! 
TRITON_ENABLE_MALI_GPU
+    } else if (group.kind() == inference::ModelInstanceGroup::KIND_CPU) {
+      if (group.gpus().size() > 0) {
+        return Status(
+            Status::Code::INVALID_ARG,
+            "instance group " + group.name() + " of model " + config.name() +
+                " has kind KIND_CPU but specifies one or more GPUs");
+      }
+    } else {
+      return Status(
+          Status::Code::INTERNAL, "instance group " + group.name() +
+                                      " of model " + config.name() +
+                                      " has unexpected kind KIND_AUTO");
+    }
+
+    if ((config.platform() != kTensorRTPlanPlatform) &&
+        !group.profile().empty()) {
+      return Status(
+          Status::Code::INVALID_ARG,
+          "instance group " + group.name() + " of model " + config.name() +
+              " and platform " + config.platform() +
+              " specifies profile field which is only supported for "
+              "TensorRT models");
+    } else if (!group.profile().empty()) {
+      for (const auto& profile : group.profile()) {
+        int profile_index;
+        RETURN_IF_ERROR(GetProfileIndex(profile, &profile_index));
+        if (profile_index < 0) {
+          return Status(
+              Status::Code::INVALID_ARG,
+              "instance group " + group.name() + " of model " + config.name() +
+                  " and platform " + config.platform() +
+                  " specifies invalid profile " + profile +
+                  ". The field should contain the string representation of a "
+                  "non-negative integer.");
+        }
+      }
+    }
+  }
+  return Status::Success;
+}
+
+Status
+ValidateModelInput(
+    const inference::ModelInput& io, int32_t max_batch_size,
+    const std::string& platform)
+{
+  RETURN_IF_ERROR(ValidateIOShape(io, max_batch_size, "model input "));
+
+  if (((io.format() == inference::ModelInput::FORMAT_NHWC) ||
+       (io.format() == inference::ModelInput::FORMAT_NCHW)) &&
+      (io.dims_size() != 3)) {
+    return Status(
+        Status::Code::INVALID_ARG, "model input NHWC/NCHW require 3 dims");
+  }
+
+  if ((platform != kTensorRTPlanPlatform) && io.is_shape_tensor()) {
+    return Status(
+        Status::Code::INVALID_ARG,
+        "shape tensors are only supported for TensorRT platform");
+  }
+
+  return Status::Success;
+}
+
+Status
+CheckAllowedModelInput(
+    const inference::ModelInput& io, const std::set<std::string>& allowed)
+{
+  if (allowed.find(io.name()) == allowed.end()) {
+    std::string astr;
+    for (const auto& a : allowed) {
+      if (!astr.empty()) {
+        astr.append(", ");
+      }
+      astr.append(a);
+    }
+
+    return Status(
+        Status::Code::INVALID_ARG, "unexpected inference input '" + io.name() +
+                                       "', allowed inputs are: " + astr);
+  }
+  return Status::Success;
+}
+
+Status
+ValidateModelOutput(
+    const inference::ModelOutput& io, int32_t max_batch_size,
+    const std::string& platform)
+{
+  RETURN_IF_ERROR(ValidateIOShape(io, max_batch_size, "model output "));
+
+  if ((platform != kTensorRTPlanPlatform) && io.is_shape_tensor()) {
+    return Status(
+        Status::Code::INVALID_ARG,
+        "shape tensors are only supported for TensorRT platform");
+  }
+
+  return Status::Success;
+}
+
+Status
+CheckAllowedModelOutput(
+    const inference::ModelOutput& io, const std::set<std::string>& allowed)
+{
+  if (allowed.find(io.name()) == allowed.end()) {
+    std::string astr;
+    for (const auto& a : allowed) {
+      if (!astr.empty()) {
+        astr.append(", ");
+      }
+      astr.append(a);
+    }
+
+    return Status(
+        Status::Code::INVALID_ARG, "unexpected inference output '" +
+                                       io.name() +
+                                       "', allowed outputs are: " + astr);
+  }
+
+  return Status::Success;
+}
+
+Status
+ParseBoolParameter(
+    const std::string& key, std::string value, bool* parsed_value)
+{
+  std::transform(
+      value.begin(), value.end(), value.begin(),
+      [](unsigned char c) { return std::tolower(c); });
+
+  if ((value == "true") || (value == "1")) {
+    *parsed_value = true;
+  } else if
((value == "false") || (value == "0")) { + *parsed_value = false; + } else { + return Status( + Status::Code::INVALID_ARG, + "failed to convert " + key + " '" + value + "' to boolean value"); + } + + return Status::Success; +} + +Status +ParseLongLongParameter( + const std::string& key, const std::string& value, int64_t* parsed_value) +{ + try { + *parsed_value = std::stoll(value); + } + catch (const std::invalid_argument& ia) { + return Status( + Status::Code::INVALID_ARG, + "failed to convert " + key + " '" + value + "' to integral number"); + } + + return Status::Success; +} + +Status +GetProfileIndex(const std::string& profile_name, int* profile_index) +{ + if (profile_name.empty()) { + return Status(Status::Code::INVALID_ARG, "profile name must not be empty"); + } + + try { + *profile_index = stoi(profile_name); + } + catch (const std::invalid_argument& ia) { + return Status( + Status::Code::INVALID_ARG, + "unable to parse '" + profile_name + "': " + ia.what()); + } + + return Status::Success; +} + +namespace { + +Status +CollectInt64Fields( + google::protobuf::Message* message, const std::string& prefix, + std::set* int64_fields) +{ + const google::protobuf::Descriptor* desc = message->GetDescriptor(); + const google::protobuf::Reflection* refl = message->GetReflection(); + for (int i = 0; i < desc->field_count(); ++i) { + const google::protobuf::FieldDescriptor* field = desc->field(i); + const std::string fullname = prefix + "::" + field->name(); + switch (field->type()) { + case google::protobuf::FieldDescriptor::TYPE_MESSAGE: { + if (field->is_repeated()) { + int rsize = refl->FieldSize(*message, field); + if (rsize == 0) { + refl->AddMessage(message, field); + } + + rsize = refl->FieldSize(*message, field); + for (int r = 0; r < rsize; ++r) { + RETURN_IF_ERROR(CollectInt64Fields( + refl->MutableRepeatedMessage(message, field, r), fullname, + int64_fields)); + } + } else { + RETURN_IF_ERROR(CollectInt64Fields( + refl->MutableMessage(message, field), fullname, int64_fields)); + } + } break; + + case google::protobuf::FieldDescriptor::TYPE_INT64: + case google::protobuf::FieldDescriptor::TYPE_UINT64: + case google::protobuf::FieldDescriptor::TYPE_SINT64: + case google::protobuf::FieldDescriptor::TYPE_FIXED64: + case google::protobuf::FieldDescriptor::TYPE_SFIXED64: + int64_fields->insert(fullname); + break; + + default: + break; + } + } + + return Status::Success; +} + +Status +ValidateModelConfigInt64() +{ + // Must initialize a dummy ModelConfig so that all fields are + // visited. + inference::ModelConfig config; + + std::set int64_fields; + RETURN_IF_ERROR(CollectInt64Fields(&config, "ModelConfig", &int64_fields)); + + LOG_VERBOSE(1) << "ModelConfig 64-bit fields:"; + for (const auto& f : int64_fields) { + LOG_VERBOSE(1) << "\t" << f; + } + + // We expect to find exactly the following fields. If we get an + // error from this code ModelConfig has added or removed a 64-bit + // field and we need to adjust here and in ModelConfigToJson below. 
+ std::set expected{ + "ModelConfig::input::dims", + "ModelConfig::input::reshape::shape", + "ModelConfig::output::dims", + "ModelConfig::output::reshape::shape", + "ModelConfig::version_policy::specific::versions", + "ModelConfig::dynamic_batching::max_queue_delay_microseconds", + "ModelConfig::dynamic_batching::default_queue_policy::default_timeout_" + "microseconds", + "ModelConfig::dynamic_batching::priority_queue_policy::value::default_" + "timeout_microseconds", + "ModelConfig::sequence_batching::direct::max_queue_delay_microseconds", + "ModelConfig::sequence_batching::state::dims", + "ModelConfig::sequence_batching::state::initial_state::dims", + "ModelConfig::sequence_batching::oldest::max_queue_delay_microseconds", + "ModelConfig::sequence_batching::max_sequence_idle_microseconds", + "ModelConfig::ensemble_scheduling::step::model_version", + "ModelConfig::model_warmup::inputs::value::dims", + "ModelConfig::optimization::cuda::graph_spec::input::value::dim", + "ModelConfig::optimization::cuda::graph_spec::graph_lower_bound::input::" + "value::dim", + "ModelConfig::instance_group::secondary_devices::device_id"}; + + if (int64_fields != expected) { + return Status( + Status::Code::INTERNAL, "ModelConfig 64-bit field needs update"); + } + + return Status::Success; +} + +Status +FixInt( + triton::common::TritonJson::Value& document, + triton::common::TritonJson::Value& io, const std::string& name) +{ + triton::common::TritonJson::Value str_value; + if (!io.Find(name.c_str(), &str_value)) { + return Status::Success; + } + + std::string str; + RETURN_IF_ERROR(str_value.AsString(&str)); + + int64_t d; + try { + d = std::atoll(str.c_str()); + } + catch (...) { + return Status( + Status::Code::INTERNAL, + (std::string("unable to convert '") + str + "' to integer")); + } + + str_value.SetInt(d); + + return Status::Success; +} + +Status +FixIntArray( + triton::common::TritonJson::Value& document, + triton::common::TritonJson::Value& io, const std::string& name) +{ + triton::common::TritonJson::Value fixed_shape_array( + document, triton::common::TritonJson::ValueType::ARRAY); + + if (!io.Find(name.c_str())) { + return Status::Success; + } + + triton::common::TritonJson::Value shape_array; + RETURN_IF_ERROR(io.MemberAsArray(name.c_str(), &shape_array)); + for (size_t i = 0; i < shape_array.ArraySize(); ++i) { + std::string str; + RETURN_IF_ERROR(shape_array.IndexAsString(i, &str)); + + int64_t d; + try { + d = std::atoll(str.c_str()); + } + catch (...) { + return Status( + Status::Code::INTERNAL, + (std::string("unable to convert '") + str + "' to integer")); + } + + RETURN_IF_ERROR(fixed_shape_array.AppendInt(d)); + } + + shape_array.Swap(fixed_shape_array); + fixed_shape_array.Release(); + + return Status::Success; +} + +Status +FixObjectArray( + triton::common::TritonJson::Value& document, + triton::common::TritonJson::Value& arr, const std::string& name) +{ + for (size_t i = 0; i < arr.ArraySize(); ++i) { + triton::common::TritonJson::Value obj; + RETURN_IF_ERROR(arr.IndexAsObject(i, &obj)); + RETURN_IF_ERROR(FixInt(document, obj, name)); + } + + return Status::Success; +} + +} // namespace + +Status +ModelConfigToJson( + const inference::ModelConfig& config, const uint32_t config_version, + std::string* json_str) +{ + // Currently only support 'config_version' 1, which is the json + // representation of the ModelConfig protobuf with the int64 fields + // fixes to be actual numbers instead of the string madness done by + // protobuf. 
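+  // For example, protobuf's JSON printer emits a repeated int64 'dims' field
+  // as ["3", "224", "224"]; the Fix* helpers below rewrite such values to
+  // [3, 224, 224] so callers get real JSON numbers (example values are
+  // illustrative only).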
+ if (config_version != 1) { + return Status( + Status::Code::INVALID_ARG, + std::string("model configuration version ") + + std::to_string(config_version) + + " not supported, supported versions are: 1"); + } + + // Config will have 0 byte size if all fields are with default value, + // in other word the config is empty. + if (config.ByteSizeLong() == 0) { + json_str->clear(); + return Status::Success; + } + + std::string config_json_str; + ::google::protobuf::util::JsonPrintOptions options; + options.preserve_proto_field_names = true; + options.always_print_primitive_fields = true; + ::google::protobuf::util::MessageToJsonString( + config, &config_json_str, options); + + // We need to verify that every field 64-bit field in the + // ModelConfig protobuf is being handled. We hardcode the known + // fields and check just once to make sure everything has been + // handled. We could have this check in a separately compiled CI + // test but it is convenient to keep it here close to the code below + // that actually fixes the 64-bit fields. + { + static std::once_flag fonce; + Status status = Status::Success; + std::call_once(fonce, [&status] { status = ValidateModelConfigInt64(); }); + RETURN_IF_ERROR(status); + } + + // In the json produced by protobuf, int64 and uint64 values are + // represented as strings. Protobuf doesn't provide an option to + // disable this (sigh) so we need to fix it up here as we want the + // json representation of the config to be reasonable json... + triton::common::TritonJson::Value config_json; + config_json.Parse(config_json_str); + + // Fix input::dims, input::reshape::shape, output::dims, + // output::reshape::shape + for (std::string name : {"input", "output"}) { + triton::common::TritonJson::Value ios; + RETURN_IF_ERROR(config_json.MemberAsArray(name.c_str(), &ios)); + for (size_t i = 0; i < ios.ArraySize(); ++i) { + triton::common::TritonJson::Value io; + RETURN_IF_ERROR(ios.IndexAsObject(i, &io)); + RETURN_IF_ERROR(FixIntArray(config_json, io, "dims")); + + triton::common::TritonJson::Value reshape; + if (io.Find("reshape", &reshape)) { + RETURN_IF_ERROR(FixIntArray(config_json, reshape, "shape")); + } + } + } + + // Fix version_policy::specific::versions + { + triton::common::TritonJson::Value vp; + if (config_json.Find("version_policy", &vp)) { + triton::common::TritonJson::Value specific; + if (vp.Find("specific", &specific)) { + RETURN_IF_ERROR(FixIntArray(config_json, specific, "versions")); + } + } + } + + // Fix dynamic_batching::max_queue_delay_microseconds, + // dynamic_batching::default_queue_policy::default_timeout_microseconds, + // dynamic_batching::priority_queue_policy::value::default_timeout_microseconds + { + triton::common::TritonJson::Value db; + if (config_json.Find("dynamic_batching", &db)) { + RETURN_IF_ERROR(FixInt(config_json, db, "max_queue_delay_microseconds")); + triton::common::TritonJson::Value dqp; + if (db.Find("default_queue_policy", &dqp)) { + RETURN_IF_ERROR( + FixInt(config_json, dqp, "default_timeout_microseconds")); + } + triton::common::TritonJson::Value pqp; + if (db.Find("priority_queue_policy", &pqp)) { + // Iterate over each member in 'pqp' and fix... 
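+        // 'priority_queue_policy' is a protobuf map keyed by priority level,
+        // so the 'default_timeout_microseconds' of every value object needs
+        // the same string-to-number fix.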
+ std::vector members; + RETURN_IF_ERROR(pqp.Members(&members)); + for (const auto& m : members) { + triton::common::TritonJson::Value el; + RETURN_IF_ERROR(pqp.MemberAsObject(m.c_str(), &el)); + RETURN_IF_ERROR( + FixInt(config_json, el, "default_timeout_microseconds")); + } + } + } + } + + // Fix sequence_batching::oldest::max_queue_delay_microseconds, + // sequence_batching::direct::max_queue_delay_microseconds, + // sequence_batching::max_sequence_idle_microseconds + { + triton::common::TritonJson::Value sb; + if (config_json.Find("sequence_batching", &sb)) { + RETURN_IF_ERROR( + FixInt(config_json, sb, "max_sequence_idle_microseconds")); + triton::common::TritonJson::Value oldest; + if (sb.Find("oldest", &oldest)) { + RETURN_IF_ERROR( + FixInt(config_json, oldest, "max_queue_delay_microseconds")); + } + triton::common::TritonJson::Value direct; + if (sb.Find("direct", &direct)) { + RETURN_IF_ERROR( + FixInt(config_json, direct, "max_queue_delay_microseconds")); + } + + triton::common::TritonJson::Value states; + if (sb.Find("state", &states)) { + for (size_t i = 0; i < states.ArraySize(); ++i) { + triton::common::TritonJson::Value state; + RETURN_IF_ERROR(states.IndexAsObject(i, &state)); + RETURN_IF_ERROR(FixIntArray(config_json, state, "dims")); + + triton::common::TritonJson::Value initial_state; + if (sb.Find("initial_state", &initial_state)) { + RETURN_IF_ERROR(FixIntArray(config_json, initial_state, "dims")); + } + } + } + } + } + + // Fix ensemble_scheduling::step::model_version. + { + triton::common::TritonJson::Value ens; + if (config_json.Find("ensemble_scheduling", &ens)) { + triton::common::TritonJson::Value step; + if (ens.Find("step", &step)) { + RETURN_IF_ERROR(FixObjectArray(config_json, step, "model_version")); + } + } + } + + // Fix model_warmup::inputs::value::dims. + { + triton::common::TritonJson::Value warmups; + if (config_json.Find("model_warmup", &warmups)) { + for (size_t i = 0; i < warmups.ArraySize(); ++i) { + triton::common::TritonJson::Value warmup; + RETURN_IF_ERROR(warmups.IndexAsObject(i, &warmup)); + triton::common::TritonJson::Value inputs; + if (warmup.Find("inputs", &inputs)) { + std::vector members; + RETURN_IF_ERROR(inputs.Members(&members)); + for (const auto& m : members) { + triton::common::TritonJson::Value input; + RETURN_IF_ERROR(inputs.MemberAsObject(m.c_str(), &input)); + RETURN_IF_ERROR(FixIntArray(config_json, input, "dims")); + } + } + } + } + } + + // Convert fixed json back the string... + triton::common::TritonJson::WriteBuffer buffer; + RETURN_IF_ERROR(config_json.Write(&buffer)); + *json_str = std::move(buffer.MutableContents()); + + return Status::Success; +} + +Status +JsonToModelConfig( + const std::string& json_config, const uint32_t config_version, + inference::ModelConfig* protobuf_config) +{ + // Currently only support 'config_version' 1, which is the json + // representation of the ModelConfig protobuf matches the representation in + // ModelConfigToJson(). 
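+  // Note that unknown JSON fields are rejected ('ignore_unknown_fields' is
+  // set to false below), so misspelled field names surface as INVALID_ARG
+  // instead of being silently dropped.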
+ if (config_version != 1) { + return Status( + Status::Code::INVALID_ARG, + std::string("model configuration version ") + + std::to_string(config_version) + + " not supported, supported versions are: 1"); + } + + ::google::protobuf::util::JsonParseOptions options; + options.case_insensitive_enum_parsing = true; + options.ignore_unknown_fields = false; + auto err = ::google::protobuf::util::JsonStringToMessage( + json_config, protobuf_config, options); + if (!err.ok()) { + return Status(Status::Code::INVALID_ARG, std::string(err.message())); + } + + return Status::Success; +} + +BackendType +GetBackendTypeFromPlatform(const std::string& platform_name) +{ + if ((platform_name == kTensorFlowGraphDefPlatform) || + (platform_name == kTensorFlowSavedModelPlatform)) { + return BackendType::BACKEND_TYPE_TENSORFLOW; + } + + if (platform_name == kTensorRTPlanPlatform) { + return BackendType::BACKEND_TYPE_TENSORRT; + } + + if (platform_name == kOnnxRuntimeOnnxPlatform) { + return BackendType::BACKEND_TYPE_ONNXRUNTIME; + } + + if (platform_name == kPyTorchLibTorchPlatform) { + return BackendType::BACKEND_TYPE_PYTORCH; + } + + return BackendType::BACKEND_TYPE_UNKNOWN; +} + +/// Get the BackendType value for a backend name. +/// \param backend_name The backend name. +/// \return The BackendType or BackendType::UNKNOWN if the platform string +/// is not recognized. +BackendType +GetBackendType(const std::string& backend_name) +{ + if (backend_name == kTensorFlowBackend) { + return BackendType::BACKEND_TYPE_TENSORFLOW; + } + + if (backend_name == kTensorRTBackend) { + return BackendType::BACKEND_TYPE_TENSORRT; + } + + if (backend_name == kOnnxRuntimeBackend) { + return BackendType::BACKEND_TYPE_ONNXRUNTIME; + } + + if (backend_name == kPyTorchBackend) { + return BackendType::BACKEND_TYPE_PYTORCH; + } + + return BackendType::BACKEND_TYPE_UNKNOWN; +} + +TRITONSERVER_DataType +DataTypeToTriton(const inference::DataType dtype) +{ + switch (dtype) { + case inference::DataType::TYPE_BOOL: + return TRITONSERVER_TYPE_BOOL; + case inference::DataType::TYPE_UINT8: + return TRITONSERVER_TYPE_UINT8; + case inference::DataType::TYPE_UINT16: + return TRITONSERVER_TYPE_UINT16; + case inference::DataType::TYPE_UINT32: + return TRITONSERVER_TYPE_UINT32; + case inference::DataType::TYPE_UINT64: + return TRITONSERVER_TYPE_UINT64; + case inference::DataType::TYPE_INT8: + return TRITONSERVER_TYPE_INT8; + case inference::DataType::TYPE_INT16: + return TRITONSERVER_TYPE_INT16; + case inference::DataType::TYPE_INT32: + return TRITONSERVER_TYPE_INT32; + case inference::DataType::TYPE_INT64: + return TRITONSERVER_TYPE_INT64; + case inference::DataType::TYPE_FP16: + return TRITONSERVER_TYPE_FP16; + case inference::DataType::TYPE_FP32: + return TRITONSERVER_TYPE_FP32; + case inference::DataType::TYPE_FP64: + return TRITONSERVER_TYPE_FP64; + case inference::DataType::TYPE_STRING: + return TRITONSERVER_TYPE_BYTES; + case inference::DataType::TYPE_BF16: + return TRITONSERVER_TYPE_BF16; + default: + break; + } + + return TRITONSERVER_TYPE_INVALID; +} + +inference::DataType +TritonToDataType(const TRITONSERVER_DataType dtype) +{ + switch (dtype) { + case TRITONSERVER_TYPE_BOOL: + return inference::DataType::TYPE_BOOL; + case TRITONSERVER_TYPE_UINT8: + return inference::DataType::TYPE_UINT8; + case TRITONSERVER_TYPE_UINT16: + return inference::DataType::TYPE_UINT16; + case TRITONSERVER_TYPE_UINT32: + return inference::DataType::TYPE_UINT32; + case TRITONSERVER_TYPE_UINT64: + return inference::DataType::TYPE_UINT64; + case 
TRITONSERVER_TYPE_INT8: + return inference::DataType::TYPE_INT8; + case TRITONSERVER_TYPE_INT16: + return inference::DataType::TYPE_INT16; + case TRITONSERVER_TYPE_INT32: + return inference::DataType::TYPE_INT32; + case TRITONSERVER_TYPE_INT64: + return inference::DataType::TYPE_INT64; + case TRITONSERVER_TYPE_FP16: + return inference::DataType::TYPE_FP16; + case TRITONSERVER_TYPE_FP32: + return inference::DataType::TYPE_FP32; + case TRITONSERVER_TYPE_FP64: + return inference::DataType::TYPE_FP64; + case TRITONSERVER_TYPE_BYTES: + return inference::DataType::TYPE_STRING; + case TRITONSERVER_TYPE_BF16: + return inference::DataType::TYPE_BF16; + default: + break; + } + + return inference::DataType::TYPE_INVALID; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/model_config_utils.h b/3rdparty/core-r22.12/src/model_config_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f514ef3c48dab55945ec0bf650797e8f49edb7f4 --- /dev/null +++ b/3rdparty/core-r22.12/src/model_config_utils.h @@ -0,0 +1,282 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include "model_config.pb.h" +#include "status.h" +#include "triton/common/model_config.h" +#include "tritonserver_apis.h" +#include "filesystem.h" + +namespace triton { namespace core { + +/// Enumeration for the different backend types. +enum BackendType { + BACKEND_TYPE_UNKNOWN = 0, + BACKEND_TYPE_TENSORRT = 1, + BACKEND_TYPE_TENSORFLOW = 2, + BACKEND_TYPE_ONNXRUNTIME = 3, + BACKEND_TYPE_PYTORCH = 4 +}; + +// Get version of a model from the path containing the model +/// definition file. +/// \param path The path to the model definition file. +/// \param version Returns the version. +/// \return The error status. +Status GetModelVersionFromPath(const std::string& path, int64_t* version); + +/// Get the tensor name, false value, and true value for a boolean +/// sequence batcher control kind. If 'required' is true then must +/// find a tensor for the control. 
If 'required' is false, return +/// 'tensor_name' as empty-string if the control is not mapped to any +/// tensor. +Status GetBooleanSequenceControlProperties( + const inference::ModelSequenceBatching& batcher, + const std::string& model_name, + const inference::ModelSequenceBatching::Control::Kind control_kind, + const bool required, std::string* tensor_name, + inference::DataType* tensor_datatype, float* fp32_false_value, + float* fp32_true_value, int32_t* int32_false_value, + int32_t* int32_true_value, bool* bool_false_value, bool* bool_true_value); + +/// Get the tensor name and datatype for a non-boolean sequence +/// batcher control kind. If 'required' is true then must find a +/// tensor for the control. If 'required' is false, return +/// 'tensor_name' as empty-string if the control is not mapped to any +/// tensor. 'tensor_datatype' returns the required datatype for the +/// control. +Status GetTypedSequenceControlProperties( + const inference::ModelSequenceBatching& batcher, + const std::string& model_name, + const inference::ModelSequenceBatching::Control::Kind control_kind, + const bool required, std::string* tensor_name, + inference::DataType* tensor_datatype); + +/// Read a ModelConfig and normalize it as expected by model backends. +/// \param path The full-path to the directory containing the +/// model configuration. +/// \param min_compute_capability The minimum support CUDA compute +/// capability. +/// \param config Returns the normalized model configuration. +/// \return The error status. +Status GetNormalizedModelConfig( + const std::string& model_name, const std::string& path, + const double min_compute_capability, inference::ModelConfig* config); + +/// Auto-complete backend related fields (platform, backend and default model +/// filename) if not set, note that only Triton recognized backends will be +/// checked. +/// \param model_name The name of the model. +/// \param model_path The full-path to the directory containing the +/// model configuration. +/// \param config Returns the auto-completed model configuration. +/// \return The error status. +Status AutoCompleteBackendFields( + const std::string& model_name, const std::string& model_path, + inference::ModelConfig* config); + +/// Detects and adds missing fields in the model configuration. +/// \param min_compute_capability The minimum supported CUDA compute +/// capability. +/// \param config The model configuration +/// \return The error status +Status NormalizeModelConfig( + const double min_compute_capability, inference::ModelConfig* config); + +/// [FIXME] better formalize config normalization / validation +/// Detects and adds missing fields in instance group setting. +/// \param min_compute_capability The minimum supported CUDA compute +/// capability. +/// \param config The model configuration +/// \return The error status +Status NormalizeInstanceGroup( + const double min_compute_capability, + const std::vector& preferred_groups, + inference::ModelConfig* config); + +/// [FIXME] Remove once a more permanent solution is implemented (DLIS-4211) +/// Localize EXECUTION_ENV_PATH in python backend. +/// \param model_path The full-path to the directory containing the model +/// configuration, before localization. 
+/// \param config The model configuration +/// \param localized_model_dir The localized model directory +/// \return The error status +Status LocalizePythonBackendExecutionEnvironmentPath( + const std::string& model_path, inference::ModelConfig* config, + std::shared_ptr* localized_model_dir); + +/// Auto-complete the instance count based on instance kind and backend name. +/// \param group The instance group to set the count for. +/// \param backend The backend name to check against. +/// \return The error status. +Status SetDefaultInstanceCount( + inference::ModelInstanceGroup* group, const std::string& backend); + +/// Validate that a model is specified correctly, except for model inputs +/// and outputs. ValidateModelIOConfig() should be called to +/// validate model inputs and outputs. +/// \param config The model configuration to validate. +/// \param min_compute_capability The minimum support CUDA compute +/// capability. +/// \return The error status. A non-OK status indicates the configuration +/// is not valid. +Status ValidateModelConfig( + const inference::ModelConfig& config, const double min_compute_capability); + +/// [FIXME] better formalize config normalization / validation +/// Validate instance group setting. +/// \param config The model configuration to validate. +/// \param min_compute_capability The minimum support CUDA compute +/// capability. +/// \return The error status. A non-OK status indicates the configuration +/// is not valid. +Status ValidateInstanceGroup( + const inference::ModelConfig& config, const double min_compute_capability); + +/// Validate that a model inputs and outputs are specified correctly. +/// \param config The model configuration to validate. +/// \return The error status. A non-OK status indicates the configuration +/// is not valid. +Status ValidateModelIOConfig(const inference::ModelConfig& config); + +/// Validate that input is specified correctly in a model +/// configuration. +/// \param io The model input. +/// \param max_batch_size The max batch size specified in model configuration. +/// \param platform The platform name +/// \return The error status. A non-OK status indicates the input +/// is not valid. +Status ValidateModelInput( + const inference::ModelInput& io, int32_t max_batch_size, + const std::string& platform); + +/// Validate that an input matches one of the allowed input names. +/// \param io The model input. +/// \param allowed The set of allowed input names. +/// \return The error status. A non-OK status indicates the input +/// is not valid. +Status CheckAllowedModelInput( + const inference::ModelInput& io, const std::set& allowed); + +/// Validate that an output is specified correctly in a model +/// configuration. +/// \param io The model output. +/// \param max_batch_size The max batch size specified in model configuration. +/// \param platform The platform name +/// \return The error status. A non-OK status indicates the output +/// is not valid. +Status ValidateModelOutput( + const inference::ModelOutput& io, int32_t max_batch_size, + const std::string& platform); + +/// Validate that an output matches one of the allowed output names. +/// \param io The model output. +/// \param allowed The set of allowed output names. +/// \return The error status. A non-OK status indicates the output +/// is not valid. +Status CheckAllowedModelOutput( + const inference::ModelOutput& io, const std::set& allowed); + +/// Validate that a model batch inputs and batch outputs are specified +/// correctly. 
+/// \param config The model configuration to validate.. +/// \return The error status. A non-OK status indicates the batch inputs or +/// batch outputs are not valid. +Status ValidateBatchIO(const inference::ModelConfig& config); + +/// Parse the 'value' of the parameter 'key' into a boolean value. +/// \param key The name of the parameter. +/// \param value The value of the parameter in string. +/// \param parsed_value Return the boolean of the parameter. +/// \return The error status. A non-OK status indicates failure on parsing the +/// value. +Status ParseBoolParameter( + const std::string& key, std::string value, bool* parsed_value); + +/// Parse the 'value' of the parameter 'key' into a long long integer value. +/// \param key The name of the parameter. +/// \param value The value of the parameter in string. +/// \param parsed_value Return the numerical value of the parameter. +/// \return The error status. A non-OK status indicates failure on parsing the +/// value. +Status ParseLongLongParameter( + const std::string& key, const std::string& value, int64_t* parsed_value); + +/// Obtain the 'profile_index' of the 'profile_name'. +/// \param profile_name The name of the profile. +/// \param profile_index Return the index of the profile. +/// \return The error status. A non-OK status indicates failure on getting the +/// value. +Status GetProfileIndex(const std::string& profile_name, int* profile_index); + +/// Convert a model configuration protobuf to the equivalent json. +/// \param config The protobuf model configuration. +/// \param config_version The model configuration will be returned in +/// a format matching this version. If the configuration cannot be +/// represented in the requested version's format then an error will +/// be returned. +/// \param json Returns the equivalent JSON. +/// \return The error status. +Status ModelConfigToJson( + const inference::ModelConfig& config, const uint32_t config_version, + std::string* json_str); + +/// Convert a model configuration JSON to the equivalent protobuf. +/// \param config The JSON model configuration. +/// \param config_version The model configuration will be returned in +/// a format matching this version. If the configuration cannot be +/// represented in the requested version's format then an error will +/// be returned. +/// \param protobuf Returns the equivalent protobuf. +/// \return The error status. +Status JsonToModelConfig( + const std::string& json_config, const uint32_t config_version, + inference::ModelConfig* protobuf_config); + +/// Get the BackendType value for a platform name. +/// \param platform_name The platform name. +/// \return The BackendType or BackendType::UNKNOWN if the platform string +/// is not recognized. +BackendType GetBackendTypeFromPlatform(const std::string& platform_name); + +/// Get the BackendType value for a backend name. +/// \param backend_name The backend name. +/// \return The BackendType or BackendType::UNKNOWN if the platform string +/// is not recognized. +BackendType GetBackendType(const std::string& backend_name); + +/// Get the Triton server data type corresponding to a data type. +/// \param dtype The data type. +/// \return The Triton server data type. +TRITONSERVER_DataType DataTypeToTriton(const inference::DataType dtype); + +/// Get the data type corresponding to a Triton server data type. +/// \param dtype The Triton server data type. +/// \return The data type. 
+inference::DataType TritonToDataType(const TRITONSERVER_DataType dtype); + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/model_lifecycle.cc b/3rdparty/core-r22.12/src/model_lifecycle.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d37a422b439cdf46465210e577b4f3436480493 --- /dev/null +++ b/3rdparty/core-r22.12/src/model_lifecycle.cc @@ -0,0 +1,740 @@ +// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+//
+
+#include "model_lifecycle.h"
+
+#include
+#include
+#include
+#include
+#include
+#include "constants.h"
+#include "filesystem.h"
+#include "model.h"
+#include "model_config_utils.h"
+#include "repo_agent.h"
+#include "triton/common/logging.h"
+#include "triton/common/thread_pool.h"
+
+#include "backend_model.h"
+#ifdef TRITON_ENABLE_ENSEMBLE
+#include "ensemble_model.h"
+#endif  // TRITON_ENABLE_ENSEMBLE
+
+namespace triton { namespace core {
+
+const std::string&
+ModelReadyStateString(ModelReadyState state)
+{
+  switch (state) {
+    case ModelReadyState::UNKNOWN: {
+      static std::string m("UNKNOWN");
+      return m;
+    }
+    case ModelReadyState::READY: {
+      static std::string m("READY");
+      return m;
+    }
+    case ModelReadyState::UNAVAILABLE: {
+      static std::string m("UNAVAILABLE");
+      return m;
+    }
+    case ModelReadyState::LOADING: {
+      static std::string m("LOADING");
+      return m;
+    }
+    case ModelReadyState::UNLOADING: {
+      static std::string m("UNLOADING");
+      return m;
+    }
+  }
+
+  static std::string m("");
+  return m;
+}
+
+namespace {
+
+Status
+VersionsToLoad(
+    const std::string model_path, const std::string& name,
+    const inference::ModelConfig& model_config, std::set<int64_t>* versions)
+{
+  versions->clear();
+
+  // Get integral number of the version directory
+  std::set<std::string> subdirs;
+  RETURN_IF_ERROR(GetDirectorySubdirs(model_path, &subdirs));
+  std::set<int64_t, std::greater<int64_t>> existing_versions;
+  for (const auto& subdir : subdirs) {
+    if (subdir == kWarmupDataFolder || subdir == kInitialStateFolder) {
+      continue;
+    }
+    if ((subdir.length() > 1) && (subdir.front() == '0')) {
+      LOG_WARNING << "ignore version directory '" << subdir
+                  << "' which contains leading zeros in its directory name";
+      continue;
+    }
+    try {
+      int64_t version = std::stoll(subdir);
+      existing_versions.insert(version);
+    }
+    catch (const std::invalid_argument& ia) {
+      LOG_WARNING << "ignore version directory '" << subdir
+                  << "' which fails to convert to integral number";
+    }
+  }
+
+  if (model_config.version_policy().has_specific()) {
+    for (const auto& v : model_config.version_policy().specific().versions()) {
+      // Only load the specific versions that are presented in model directory
+      bool version_not_exist = existing_versions.insert(v).second;
+      if (!version_not_exist) {
+        versions->emplace(v);
+      } else {
+        LOG_ERROR << "version " << v << " is specified for model '" << name
+                  << "', but the version directory is not present";
+      }
+    }
+  } else {
+    if (model_config.version_policy().has_latest()) {
+      // std::set is sorted with std::greater<int64_t>
+      for (const auto& v : existing_versions) {
+        if (versions->size() >=
+            model_config.version_policy().latest().num_versions()) {
+          break;
+        }
+        versions->emplace(v);
+      }
+    } else {
+      // all
+      versions->insert(existing_versions.begin(), existing_versions.end());
+    }
+  }
+
+  return Status::Success;
+}
+
+// Use smart pointer with custom deleter so that model state will be updated
+// to UNAVAILABLE if all smart pointer copies are out of scope
+struct ModelDeleter {
+  ModelDeleter(std::function<void()> OnDestroyModel)
+      : OnDestroyModel_(std::move(OnDestroyModel))
+  {
+  }
+
+  void operator()(Model* model)
+  {
+    // The actual model object must be destroyed in a different
+    // thread. This thread could have a callstack that includes the
+    // model itself because this deleter could be triggered by
+    // a request release or response send in the model. Following
+    // delete will lead to the model destructor which may wait on this
+    // same thread... so deadlock if we don't use a different thread
+    // here.
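+    // Copy the callback into a local first; the detached thread below must
+    // not touch this deleter object, which may already be destroyed by the
+    // time the thread runs.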
+ std::function destroy_fn = OnDestroyModel_; + std::thread dthd([model, destroy_fn]() { + delete model; + destroy_fn(); + }); + + dthd.detach(); + } + + // Use to inform the ModelLifeCycle that the model handle is destroyed + std::function OnDestroyModel_; +}; + +} // namespace + +Status +ModelLifeCycle::Create( + InferenceServer* server, const ModelLifeCycleOptions& options, + std::unique_ptr* life_cycle) +{ + std::unique_ptr local_life_cycle( + new ModelLifeCycle(server, options)); + + *life_cycle = std::move(local_life_cycle); + return Status::Success; +} + +const ModelStateMap +ModelLifeCycle::LiveModelStates(bool strict_readiness) +{ + LOG_VERBOSE(2) << "LiveModelStates()"; + std::lock_guard map_lock(map_mtx_); + ModelStateMap live_model_states; + for (auto& model_version : map_) { + bool live = false; + VersionStateMap version_map; + + for (auto& version_model : model_version.second) { + std::lock_guard lock(version_model.second->mtx_); + if (strict_readiness && + version_model.second->state_ != ModelReadyState::READY) { + continue; + } + + // At least one version is live (ready / loading / unloading) + if ((version_model.second->state_ != ModelReadyState::UNKNOWN) && + (version_model.second->state_ != ModelReadyState::UNAVAILABLE)) { + live = true; + version_map[version_model.first] = std::make_pair( + version_model.second->state_, version_model.second->state_reason_); + } + } + + if (live) { + live_model_states[model_version.first] = std::move(version_map); + } + } + return live_model_states; +} + +Status +ModelLifeCycle::StopAllModels() +{ + LOG_VERBOSE(2) << "StopAllModels()"; + std::lock_guard map_lock(map_mtx_); + for (auto& model_version : map_) { + for (auto& version_model : model_version.second) { + if (version_model.second != nullptr) { + std::lock_guard lock(version_model.second->mtx_); + if (version_model.second->model_ != nullptr) { + version_model.second->model_->Stop(); + } + } + } + } + return Status::Success; +} + +const std::set> +ModelLifeCycle::InflightStatus() +{ + LOG_VERBOSE(2) << "InflightStatus()"; + std::lock_guard map_lock(map_mtx_); + std::set> inflight_status; + for (auto& model_version : map_) { + for (auto& version_model : model_version.second) { + if (version_model.second != nullptr) { + std::lock_guard lock(version_model.second->mtx_); + if (version_model.second->model_ != nullptr) { + const auto cnt = + version_model.second->model_->InflightInferenceCount(); + if (cnt != 0) { + inflight_status.emplace( + model_version.first, version_model.first, cnt); + } + } + } + } + } + return inflight_status; +} + +const ModelStateMap +ModelLifeCycle::ModelStates() +{ + LOG_VERBOSE(2) << "ModelStates()"; + std::lock_guard map_lock(map_mtx_); + ModelStateMap model_states; + for (auto& model_version : map_) { + VersionStateMap version_map; + + for (auto& version_model : model_version.second) { + std::lock_guard lock(version_model.second->mtx_); + version_map[version_model.first] = std::make_pair( + version_model.second->state_, version_model.second->state_reason_); + } + + model_states[model_version.first] = std::move(version_map); + } + + return model_states; +} + +const VersionStateMap +ModelLifeCycle::VersionStates(const std::string& model_name) +{ + LOG_VERBOSE(2) << "VersionStates() '" << model_name << "'"; + std::lock_guard map_lock(map_mtx_); + VersionStateMap version_map; + auto mit = map_.find(model_name); + if (mit != map_.end()) { + for (auto& version_model : mit->second) { + std::lock_guard lock(version_model.second->mtx_); + 
version_map[version_model.first] = std::make_pair( + version_model.second->state_, version_model.second->state_reason_); + } + } + + return version_map; +} + +Status +ModelLifeCycle::ModelState( + const std::string& model_name, const int64_t model_version, + ModelReadyState* state) +{ + std::lock_guard map_lock(map_mtx_); + auto mit = map_.find(model_name); + if (mit != map_.end()) { + auto vit = mit->second.find(model_version); + if (vit != mit->second.end()) { + std::lock_guard lock(vit->second->mtx_); + *state = vit->second->state_; + return Status::Success; + } + } + + return Status( + Status::Code::NOT_FOUND, "model '" + model_name + "', version " + + std::to_string(model_version) + + " is not found"); +} + +Status +ModelLifeCycle::GetModel( + const std::string& model_name, const int64_t version, + std::shared_ptr* model) +{ + LOG_VERBOSE(2) << "GetModel() '" << model_name << "' version " << version; + std::lock_guard map_lock(map_mtx_); + auto mit = map_.find(model_name); + if (mit == map_.end()) { + return Status(Status::Code::NOT_FOUND, "'" + model_name + "' is not found"); + } + + auto vit = mit->second.find(version); + if (vit == mit->second.end()) { + if (version != -1) { + return Status( + Status::Code::NOT_FOUND, "'" + model_name + "' version " + + std::to_string(version) + + " is not found"); + } + + // The case where the request is asking for latest version + int64_t latest = -1; + for (auto& version_model : mit->second) { + if (version_model.first > latest) { + std::lock_guard lock(version_model.second->mtx_); + if (version_model.second->state_ == ModelReadyState::READY) { + latest = version_model.first; + // Tedious, but have to set handle for any "latest" version + // at the moment to avoid edge case like the following: + // "versions : 1 3 2", version 3 is latest but is requested + // to be unloaded when the iterator is examining version 2, + // then 'model' will ensure version 3 is still valid + *model = version_model.second->model_; + } + } + } + if (latest == -1) { + return Status( + Status::Code::NOT_FOUND, + "'" + model_name + "' has no available versions"); + } + } else { + std::lock_guard lock(vit->second->mtx_); + if (vit->second->state_ == ModelReadyState::READY) { + *model = vit->second->model_; + } else { + return Status( + Status::Code::UNAVAILABLE, "'" + model_name + "' version " + + std::to_string(version) + + " is not at ready state"); + } + } + return Status::Success; +} + +Status +ModelLifeCycle::AsyncUnload(const std::string& model_name) +{ + LOG_VERBOSE(2) << "AsyncUnload() '" << model_name << "'"; + std::lock_guard map_lock(map_mtx_); + auto it = map_.find(model_name); + if (it == map_.end()) { + return Status( + Status::Code::INVALID_ARG, "Model to be unloaded has not been served"); + } + + // Get the existing agent models and notify the unload action + const uint64_t now_ns = + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + for (auto& version : it->second) { + auto& model_info = version.second; + std::lock_guard lock(model_info->mtx_); + model_info->last_update_ns_ = now_ns; + // Unload serving model, for model that is in LOADING state, + // the updated timestamp will be recognized that there is newer update + // on the model info and the load should be aborted + if (model_info->state_ == ModelReadyState::READY) { + if (model_info->agent_model_list_ != nullptr) { + // Only log the error because the model should be unloaded regardless + auto status = model_info->agent_model_list_->InvokeAgentModels( + 
TRITONREPOAGENT_ACTION_UNLOAD); + if (!status.IsOk()) { + LOG_ERROR + << "Agent model returns error on TRITONREPOAGENT_ACTION_UNLOAD: " + << status.AsString(); + } + } + + // unload + model_info->Release(); + } + } + + return Status::Success; +} + +Status +ModelLifeCycle::AsyncLoad( + const std::string& model_name, const std::string& model_path, + const inference::ModelConfig& model_config, const bool is_config_provided, + const std::shared_ptr& agent_model_list, + std::function&& OnComplete) +{ + LOG_VERBOSE(2) << "AsyncLoad() '" << model_name << "'"; + + std::lock_guard map_lock(map_mtx_); + auto it = map_.find(model_name); + if (it == map_.end()) { + it = map_.emplace(std::make_pair(model_name, VersionMap())).first; + } + + std::set versions; + RETURN_IF_ERROR( + VersionsToLoad(model_path, model_name, model_config, &versions)); + if (versions.empty()) { + return Status( + Status::Code::INVALID_ARG, + "at least one version must be available under the version policy of " + "model '" + + model_name + "'"); + } + + + const uint64_t now_ns = + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + std::shared_ptr load_tracker( + new LoadTracker(versions.size(), now_ns)); + for (const auto& version : versions) { + std::unique_ptr linfo( + new ModelInfo(model_path, model_config, now_ns)); + ModelInfo* model_info = linfo.get(); + + LOG_INFO << "loading: " << model_name << ":" << version; + model_info->state_ = ModelReadyState::LOADING; + model_info->state_reason_.clear(); + model_info->agent_model_list_ = agent_model_list; + + auto res = it->second.emplace( + std::make_pair(version, std::unique_ptr())); + if (res.second) { + res.first->second = std::move(linfo); + } else { + // There is already a record of this model version. Check if the version + // model is being served, if so, the re-load of the version + // should be performed in background to avoid version downtime. + // Otherwise, swap and monitor state for newly loading model. + auto& serving_model = res.first->second; + std::lock_guard lock(serving_model->mtx_); + if (serving_model->state_ == ModelReadyState::READY) { + background_models_[(uintptr_t)model_info] = std::move(linfo); + } else { + // swap the monitoring model info + serving_model.swap(linfo); + + // further check the state, put to 'background_models_' to keep + // the object valid if the model is LOADING / UNLOADING, because + // the model info will be accessed by a different thread once the + // operation is completed + if ((linfo->state_ == ModelReadyState::LOADING) || + (linfo->state_ == ModelReadyState::UNLOADING)) { + ModelInfo* key = linfo.get(); + background_models_[(uintptr_t)key] = std::move(linfo); + } + } + } + + // Load model asynchronously via thread pool + load_pool_->Enqueue([this, model_name, version, model_info, OnComplete, + load_tracker, is_config_provided]() { + CreateModel(model_name, version, model_info, is_config_provided); + OnLoadComplete(model_name, version, model_info, OnComplete, load_tracker); + }); + } + + return Status::Success; +} + +void +ModelLifeCycle::CreateModel( + const std::string& model_name, const int64_t version, ModelInfo* model_info, + const bool is_config_provided) +{ + LOG_VERBOSE(2) << "CreateModel() '" << model_name << "' version " << version; + const auto& model_config = model_info->model_config_; + + // Create model + Status status; + std::unique_ptr is; + + // If 'backend' is specified in the config then use the new triton + // backend. 
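+  // Models without a 'backend' field fall through to the ensemble path below
+  // (when built with TRITON_ENABLE_ENSEMBLE); any other platform value is
+  // rejected as an unknown platform.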
+ if (!model_config.backend().empty()) { + std::unique_ptr model; + status = TritonModel::Create( + server_, model_info->model_path_, cmdline_config_map_, host_policy_map_, + model_name, version, model_config, is_config_provided, &model); + is.reset(model.release()); + } else { +#ifdef TRITON_ENABLE_ENSEMBLE + if (model_info->is_ensemble_) { + status = EnsembleModel::Create( + server_, model_info->model_path_, version, model_config, + is_config_provided, min_compute_capability_, &is); + // Complete label provider with label information from involved models + // Must be done here because involved models may not be able to + // obtained from server because this may happen during server + // initialization. + if (status.IsOk()) { + std::set no_label_outputs; + const auto& label_provider = is->GetLabelProvider(); + for (const auto& output : model_config.output()) { + if (label_provider->GetLabel(output.name(), 0).empty()) { + no_label_outputs.emplace(output.name()); + } + } + for (const auto& element : model_config.ensemble_scheduling().step()) { + for (const auto& pair : element.output_map()) { + // Found model that produce one of the missing output + if (no_label_outputs.find(pair.second) != no_label_outputs.end()) { + std::shared_ptr model; + // Safe to obtain model because the ensemble can't be loaded + // until the involved models are ready + GetModel(element.model_name(), element.model_version(), &model); + label_provider->AddLabels( + pair.second, + model->GetLabelProvider()->GetLabels(pair.first)); + } + } + } + } + } else +#endif // TRITON_ENABLE_ENSEMBLE + { + status = Status( + Status::Code::INVALID_ARG, + "unknown platform '" + model_config.platform() + "'"); + } + } + + std::lock_guard lock(model_info->mtx_); + if (status.IsOk()) { + // [FIXME] better way to manage agent model lifecycle + // Let the deleter also holds a shared pointer copy of agent model list, + // because the reference in ModelInfo can be cleared before the Model object + // is destroyed, and we want agent model to be valid for receiving + // UNLOAD_COMPLETE signal (see ~TritonRepoAgentModelList for detail) + auto agent_model_list = model_info->agent_model_list_; + model_info->model_.reset( + is.release(), ModelDeleter([this, model_name, version, model_info, + agent_model_list]() mutable { + LOG_VERBOSE(2) << "OnDestroy callback() '" << model_name + << "' version " << version; + LOG_INFO << "successfully unloaded '" << model_name << "' version " + << version; + // Update model state as it is fully unloaded + { + std::lock_guard lock(model_info->mtx_); + model_info->state_ = ModelReadyState::UNAVAILABLE; + model_info->state_reason_ = "unloaded"; + } + + // Check if the model info is in background, if so, remove from the + // map + std::lock_guard lk(this->map_mtx_); + auto it = this->background_models_.find((uintptr_t)model_info); + if (it != this->background_models_.end()) { + this->background_models_.erase(it); + } + })); + } else { + LOG_ERROR << "failed to load '" << model_name << "' version " << version + << ": " << status.AsString(); + model_info->state_ = ModelReadyState::UNAVAILABLE; + model_info->state_reason_ = status.AsString(); + } +} + +void +ModelLifeCycle::OnLoadComplete( + const std::string& model_name, const int64_t version, ModelInfo* model_info, + std::function OnComplete, + std::shared_ptr load_tracker) +{ + std::lock_guard tracker_lock(load_tracker->mtx_); + ++load_tracker->completed_version_cnt_; + load_tracker->load_set_[version] = model_info; + // Version will not be marked ready until all 
versions are + // ready, this simplify the unloading when one version fails to load as + // all other versions won't have inflight requests + if (model_info->state_ != ModelReadyState::LOADING) { + load_tracker->load_failed_ = true; + load_tracker->reason_ += + ("version " + std::to_string(version) + " is at " + + ModelReadyStateString(model_info->state_) + + " state: " + model_info->state_reason_ + ";"); + } + // Check if all versions are completed and finish the load + if (load_tracker->completed_version_cnt_ == + load_tracker->affected_version_cnt_) { + // hold 'map_mtx_' as there will be change onto the model info map + std::lock_guard map_lock(map_mtx_); + auto it = map_.find(model_name); + // Check if the load is the latest frontground action on the model + for (const auto& version_info : it->second) { + if (version_info.second->last_update_ns_ > + load_tracker->last_update_ns_) { + load_tracker->load_failed_ = true; + load_tracker->reason_ = + "Newer operation has been applied to the model lifecycle, current " + "load operation is out-dated."; + break; + } + } + + if (load_tracker->load_failed_) { + // Move agent list out of ModelInfo as it needs to be invoked + // after all ModelInfos are reset + std::shared_ptr lagent_list; + if (model_info->agent_model_list_) { + lagent_list = std::move(model_info->agent_model_list_); + } + // If any of the versions fails to load, abort the load and unload + // all newly loaded versions + for (auto& loaded : load_tracker->load_set_) { + // Unload directly, the object is being managed either in frontground + // or background + std::lock_guard lock(loaded.second->mtx_); + if (loaded.second->model_ != nullptr) { + loaded.second->Release(); + } + } + + if (lagent_list) { + auto status = + lagent_list->InvokeAgentModels(TRITONREPOAGENT_ACTION_LOAD_FAIL); + if (!status.IsOk()) { + LOG_ERROR << "Agent model returns error on " + "TRITONREPOAGENT_ACTION_LOAD_FAIL: " + << status.AsString(); + } + } + } else { + // Unload any previous loaded versions that are still available + for (auto& version_info : it->second) { + auto& mi = version_info.second; + std::lock_guard info_lk(mi->mtx_); + if ((mi->state_ == ModelReadyState::READY) && + (mi->last_update_ns_ < load_tracker->last_update_ns_)) { + if (mi->agent_model_list_ != nullptr) { + auto status = mi->agent_model_list_->InvokeAgentModels( + TRITONREPOAGENT_ACTION_UNLOAD); + if (!status.IsOk()) { + LOG_ERROR << "Agent model returns error on " + "TRITONREPOAGENT_ACTION_UNLOAD: " + << status.AsString(); + } + } + + mi->Release(); + } + } + + // Mark current versions ready and track info in foreground + for (auto& loaded : load_tracker->load_set_) { + std::lock_guard curr_info_lk(loaded.second->mtx_); + loaded.second->state_ = ModelReadyState::READY; + model_info->state_reason_.clear(); + LOG_INFO << "successfully loaded '" << model_name << "' version " + << version; + + auto bit = background_models_.find((uintptr_t)loaded.second); + // Check if the version model is loaded in background, if so, + // replace and unload the current serving version + if (bit != background_models_.end()) { + auto vit = it->second.find(loaded.first); + + // Need to lock the previous model info for in case the model is + // loading / unloading, this ensure the model state is consistent + // even when the load / unload is completed. 
+ std::lock_guard prev_info_lk(vit->second->mtx_); + + // swap previous info into local unique pointer + auto linfo = std::move(bit->second); + vit->second.swap(linfo); + background_models_.erase(bit); + + // if previous info is under change, put into 'background_models_' + if ((linfo->state_ == ModelReadyState::LOADING) || + (linfo->state_ == ModelReadyState::UNLOADING)) { + ModelInfo* key = linfo.get(); + background_models_[(uintptr_t)key] = std::move(linfo); + } + } + } + if (model_info->agent_model_list_) { + auto status = model_info->agent_model_list_->InvokeAgentModels( + TRITONREPOAGENT_ACTION_LOAD_COMPLETE); + if (!status.IsOk()) { + LOG_ERROR << "Agent model returns error on " + "TRITONREPOAGENT_ACTION_LOAD_COMPLETE: " + << status.AsString(); + } + } + } + if (OnComplete != nullptr) { + OnComplete( + load_tracker->load_failed_ + ? Status(Status::Code::INVALID_ARG, load_tracker->reason_) + : Status::Success); + } + } +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/model_lifecycle.h b/3rdparty/core-r22.12/src/model_lifecycle.h new file mode 100644 index 0000000000000000000000000000000000000000..e32b607219bbc155ab1332bb1ffaa0ed2971316d --- /dev/null +++ b/3rdparty/core-r22.12/src/model_lifecycle.h @@ -0,0 +1,324 @@ +// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+// +#pragma once + +#include +#include +#include +#include "infer_parameter.h" +#include "model_config.pb.h" +#include "repo_agent.h" +#include "status.h" +#include "triton/common/model_config.h" +#include "triton/common/thread_pool.h" + +namespace triton { namespace core { + +struct ModelLifeCycleOptions { + explicit ModelLifeCycleOptions( + const double min_compute_capability, + const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map, + const triton::common::HostPolicyCmdlineConfigMap& host_policy_map, + const unsigned int model_load_thread_count) + : min_compute_capability_(min_compute_capability), + backend_cmdline_config_map_(backend_cmdline_config_map), + host_policy_map_(host_policy_map), + model_load_thread_count_(model_load_thread_count) + { + } + // The minimum supported CUDA compute capability. + const double min_compute_capability_; + // The backend configuration settings specified on the command-line + const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map_; + // The host policy setting used when loading models. + const triton::common::HostPolicyCmdlineConfigMap& host_policy_map_; + // Number of the threads to use for concurrently loading models + const unsigned int model_load_thread_count_; +}; + + +/// Readiness status for models. +enum class ModelReadyState { + // The model is in an unknown state. The model is not available for + // inferencing. + UNKNOWN, + + // The model is ready and available for inferencing. + READY, + + // The model is unavailable, indicating that the model failed to + // load or has been implicitly or explicitly unloaded. The model is + // not available for inferencing. + UNAVAILABLE, + + // The model is being loaded by the inference server. The model is + // not available for inferencing. + LOADING, + + // The model is being unloaded by the inference server. The model is + // not available for inferencing. + UNLOADING +}; + +/// Get the string representation for a ModelReadyState +const std::string& ModelReadyStateString(ModelReadyState state); + +using VersionStateMap = + std::map>; +using ModelStateMap = std::map; + +// Helper class to manage the lifecycle of a list of associated agent models +class TritonRepoAgentModelList { + public: + TritonRepoAgentModelList() + : last_action_type_(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE){}; + ~TritonRepoAgentModelList() + { + // Using destructor to finish the unload lifecycle without + // explicitly managing the last step in ModelLifecycle. + if (last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD) { + InvokeAgentModels(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE); + } + } + Status AddAgentModel(std::unique_ptr&& agent_model) + { + agent_models_.emplace_back(std::move(agent_model)); + return Status::Success; + } + + size_t Size() { return agent_models_.size(); } + + TritonRepoAgentModel* Back() { return agent_models_.back().get(); } + + Status InvokeAgentModels(const TRITONREPOAGENT_ActionType action_type) + { + // Special handling for the current model lifecycle implementation, + // the repo agent may be asked to perform UNLOAD action multiple times, + // and the requests after the first should be ignored. 
+    const bool repeated_unload =
+        (action_type == TRITONREPOAGENT_ACTION_UNLOAD) &&
+        (last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD);
+    if (repeated_unload) {
+      return Status::Success;
+    }
+
+    last_action_type_ = action_type;
+    switch (action_type) {
+      case TRITONREPOAGENT_ACTION_LOAD:
+      case TRITONREPOAGENT_ACTION_UNLOAD: {
+        for (size_t idx = 0; idx < agent_models_.size(); ++idx) {
+          RETURN_IF_ERROR(agent_models_[idx]->InvokeAgent(action_type));
+        }
+        break;
+      }
+      case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
+      case TRITONREPOAGENT_ACTION_LOAD_FAIL:
+      case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: {
+        // reverse order
+        for (size_t one_pass_idx = agent_models_.size(); one_pass_idx > 0;
+             --one_pass_idx) {
+          RETURN_IF_ERROR(
+              agent_models_[one_pass_idx - 1]->InvokeAgent(action_type));
+        }
+        break;
+      }
+    }
+    return Status::Success;
+  }
+
+ private:
+  DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentModelList);
+
+  std::vector<std::unique_ptr<TritonRepoAgentModel>> agent_models_;
+  TRITONREPOAGENT_ActionType last_action_type_;
+};
+
+class InferenceServer;
+class Model;
+
+class ModelLifeCycle {
+ public:
+  static Status Create(
+      InferenceServer* server, const ModelLifeCycleOptions& options,
+      std::unique_ptr<ModelLifeCycle>* life_cycle);
+
+  ~ModelLifeCycle()
+  {
+    // Explicitly clean up thread pool first to clean up any pending callbacks
+    // that may modify model lifecycle members
+    load_pool_.reset();
+    map_.clear();
+  }
+
+  // Start loading the model with the specified versions asynchronously.
+  // All versions that are being served will be unloaded only after
+  // the load has finished successfully.
+  Status AsyncLoad(
+      const std::string& model_name, const std::string& model_path,
+      const inference::ModelConfig& model_config, const bool is_config_provided,
+      const std::shared_ptr<TritonRepoAgentModelList>& agent_model_list,
+      std::function<void(Status)>&& OnComplete);
+
+  // Unload the model asynchronously.
+  Status AsyncUnload(const std::string& model_name);
+
+  // Get the specified version of the model. The latest ready version is
+  // retrieved if 'version' is -1. Return an error if the specified version
+  // is not found or is not ready.
+  Status GetModel(
+      const std::string& model_name, const int64_t version,
+      std::shared_ptr<Model>* model);
+
+  // Get the ModelStateMap representation of the live models. A model is
+  // live if at least one of its versions is neither unknown nor unavailable.
+  // If 'strict_readiness' is true, a model is only live if
+  // at least one of its versions is ready.
+  const ModelStateMap LiveModelStates(bool strict_readiness = false);
+
+  // Get the ModelStateMap representation of the models.
+  const ModelStateMap ModelStates();
+
+  // Get the VersionStateMap representation of the specified model.
+  const VersionStateMap VersionStates(const std::string& model_name);
+
+  // Get the state of a specific model version.
+  Status ModelState(
+      const std::string& model_name, const int64_t model_version,
+      ModelReadyState* state);
+
+  // Instruct all models to stop accepting new inference requests.
+  Status StopAllModels();
+
+  // Return the number of in-flight inferences, if any; model versions
+  // that have no in-flight inferences are not included.
+ const std::set> InflightStatus(); + + private: + struct ModelInfo { + ModelInfo( + const std::string& model_path, + const inference::ModelConfig& model_config, + const uint64_t last_update_ns) + : model_config_(model_config), model_path_(model_path), +#ifdef TRITON_ENABLE_ENSEMBLE + is_ensemble_(model_config.platform() == kEnsemblePlatform), +#else + is_ensemble_(false), +#endif // TRITON_ENABLE_ENSEMBLE + last_update_ns_(last_update_ns), state_(ModelReadyState::UNKNOWN) + { + } + + // Release the flyweight in ModelInfo object, reflect as 'UNLOADING' in + // model state. Note that 'mtx_' should be acquired before invoking this + // function to prevent possible data race. + void Release() + { + state_ = ModelReadyState::UNLOADING; + state_reason_.clear(); + agent_model_list_.reset(); + model_.reset(); + } + + const inference::ModelConfig model_config_; + const std::string model_path_; + const bool is_ensemble_; + + std::mutex mtx_; + + uint64_t last_update_ns_; + + ModelReadyState state_; + std::string state_reason_; + + // flyweight + std::shared_ptr agent_model_list_; + std::shared_ptr model_; + }; + + struct LoadTracker { + LoadTracker( + const size_t affected_version_cnt, const uint64_t last_update_ns) + : last_update_ns_(last_update_ns), + affected_version_cnt_(affected_version_cnt), load_failed_(false), + completed_version_cnt_(0) + { + } + + const uint64_t last_update_ns_; + const size_t affected_version_cnt_; + + std::mutex mtx_; + + bool load_failed_; + std::string reason_; + size_t completed_version_cnt_; + std::map load_set_; + }; + + ModelLifeCycle(InferenceServer* server, const ModelLifeCycleOptions& options) + : server_(server), + min_compute_capability_(options.min_compute_capability_), + cmdline_config_map_(options.backend_cmdline_config_map_), + host_policy_map_(options.host_policy_map_) + { + load_pool_.reset(new triton::common::ThreadPool( + std::max(1u, options.model_load_thread_count_))); + } + + void CreateModel( + const std::string& model_name, const int64_t version, + ModelInfo* model_info, const bool is_config_provided); + // Callback function template for model load. + // 'OnComplete' needs to be passed by value for now as there can be + // multiple versions to be loaded and each holds a copy of + // the 'OnComplete' callback. + void OnLoadComplete( + const std::string& model_name, const int64_t version, + ModelInfo* model_info, std::function OnComplete, + std::shared_ptr load_tracker); + + + // Mutex for 'map_' and 'background_models_' + std::mutex map_mtx_; + + using VersionMap = std::map>; + using ModelMap = std::map; + ModelMap map_; + // Models that are being loaded / unloaded in background + std::map> background_models_; + + InferenceServer* server_; + const double min_compute_capability_; + const triton::common::BackendCmdlineConfigMap cmdline_config_map_; + const triton::common::HostPolicyCmdlineConfigMap host_policy_map_; + + // Fixed-size thread pool to load models at specified concurrency + std::unique_ptr load_pool_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/model_repository_manager.cc b/3rdparty/core-r22.12/src/model_repository_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..7a8f2b5ca3d67b781a5551e72157e18fadaf8380 --- /dev/null +++ b/3rdparty/core-r22.12/src/model_repository_manager.cc @@ -0,0 +1,1602 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
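Editor's note: TritonRepoAgentModelList above invokes LOAD/UNLOAD on its agents in registration order and the terminating LOAD_COMPLETE / LOAD_FAIL / UNLOAD_COMPLETE actions in reverse order. The toy program below is not Triton code; the agent names are made up. It only demonstrates that forward/reverse pairing, which mirrors constructor/destructor style nesting.

#include <cstdio>
#include <vector>

int main()
{
  // Hypothetical agent names; only the ordering matters here.
  const std::vector<const char*> agents{"checksum", "decrypt", "localize"};

  // LOAD runs front to back...
  for (auto it = agents.begin(); it != agents.end(); ++it) {
    std::printf("LOAD          -> %s\n", *it);
  }
  // ...while the terminating action unwinds back to front.
  for (auto it = agents.rbegin(); it != agents.rend(); ++it) {
    std::printf("LOAD_COMPLETE -> %s\n", *it);
  }
  return 0;
}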
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// + +#include "model_repository_manager.h" + +#include +#include +#include +#include +#include +#include "constants.h" +#include "ensemble_utils.h" +#include "filesystem.h" +#include "model.h" +#include "model_config_utils.h" +#include "triton/common/logging.h" + +#include "backend_model.h" +#ifdef TRITON_ENABLE_ENSEMBLE +#include "ensemble_model.h" +#endif // TRITON_ENABLE_ENSEMBLE + +namespace triton { namespace core { + +namespace { + +static std::string file_prefix = "file:"; + +// Internal repo agent used for model file override +class LocalizeRepoAgent : public TritonRepoAgent { + public: + LocalizeRepoAgent() + : TritonRepoAgent("ModelRepositoryManager::LocalizeRepoAgent") + { + // Callbacks below interact with TritonRepoAgentModel directly knowing that + // it is the internal implementation of TRITONREPOAGENT_AgentModel + model_action_fn_ = [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) + -> TRITONSERVER_Error* { + auto agent_model = reinterpret_cast(model); + switch (action_type) { + case TRITONREPOAGENT_ACTION_LOAD: { + // localize the override files for model loading, + // as currently the model is expected to load from local directory + const char* temp_dir_cstr = nullptr; + RETURN_TRITONSERVER_ERROR_IF_ERROR( + agent_model->AcquireMutableLocation( + TRITONREPOAGENT_ARTIFACT_FILESYSTEM, &temp_dir_cstr)); + const std::string temp_dir = temp_dir_cstr; + const auto& files = + *reinterpret_cast*>( + agent_model->State()); + bool found_config = false; + for (const auto& file : files) { + if (file->Name() == "config") { + if (file->Type() != TRITONSERVER_PARAMETER_STRING) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "Config parameter 'config' must have string type for its " + "value"); + } + inference::ModelConfig config; + RETURN_TRITONSERVER_ERROR_IF_ERROR(JsonToModelConfig( + file->ValueString(), 1 /* config_version */, &config)); + RETURN_TRITONSERVER_ERROR_IF_ERROR(WriteTextProto( + JoinPath({temp_dir, 
kModelConfigPbTxt}), config)); + found_config = true; + } else if (file->Name().rfind(file_prefix, 0) == 0) { + if (file->Type() != TRITONSERVER_PARAMETER_BYTES) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("File parameter '") + file->Name() + + "' must have bytes type for its value") + .c_str()); + } + + // Save model file to the instructed directory + // mkdir + const std::string file_path = + JoinPath({temp_dir, file->Name().substr(file_prefix.size())}); + const std::string dir = DirName(file_path); + bool dir_exist = false; + RETURN_TRITONSERVER_ERROR_IF_ERROR(FileExists(dir, &dir_exist)); + if (dir_exist) { + bool is_dir = false; + RETURN_TRITONSERVER_ERROR_IF_ERROR(IsDirectory(dir, &is_dir)); + if (!is_dir) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("Invalid file parameter '") + file->Name() + + "', directory has been created as a file") + .c_str()); + } + } else { + RETURN_TRITONSERVER_ERROR_IF_ERROR( + MakeDirectory(dir, true /* recursive */)); + } + + // write + RETURN_TRITONSERVER_ERROR_IF_ERROR(WriteBinaryFile( + file_path, + reinterpret_cast(file->ValuePointer()), + file->ValueByteSize())); + } + } + if (!found_config) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "Load parameter 'config' must be specified for model file " + "override"); + } + // Commit the temporary directory + RETURN_TRITONSERVER_ERROR_IF_ERROR(agent_model->SetLocation( + TRITONREPOAGENT_ARTIFACT_FILESYSTEM, temp_dir_cstr)); + break; + } + default: + break; + } + return nullptr; // success + }; + + model_fini_fn_ = + [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model) -> TRITONSERVER_Error* { + auto agent_model = reinterpret_cast(model); + RETURN_TRITONSERVER_ERROR_IF_ERROR(agent_model->DeleteMutableLocation()); + return nullptr; // success + }; + } +}; + +Status +CreateAgentModelListWithLoadAction( + const inference::ModelConfig& original_model_config, + const std::string& original_model_path, + std::shared_ptr* agent_model_list) +{ + if (original_model_config.has_model_repository_agents()) { + // Trick to append user specified repo agent on top of internal ones + std::shared_ptr lagent_model_list; + if (*agent_model_list != nullptr) { + lagent_model_list = std::move(*agent_model_list); + } else { + lagent_model_list.reset(new TritonRepoAgentModelList()); + } + + FileSystemType filesystem_type; + RETURN_IF_ERROR(GetFileSystemType(original_model_path, &filesystem_type)); + TRITONREPOAGENT_ArtifactType artifact_type = + TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + if (filesystem_type != FileSystemType::LOCAL) { + artifact_type = TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM; + } + const char* location = original_model_path.c_str(); + inference::ModelConfig model_config = original_model_config; + for (const auto& agent_config : + original_model_config.model_repository_agents().agents()) { + std::shared_ptr agent; + RETURN_IF_ERROR( + TritonRepoAgentManager::CreateAgent(agent_config.name(), &agent)); + TritonRepoAgent::Parameters agent_params; + for (const auto& parameter : agent_config.parameters()) { + agent_params.emplace_back(parameter.first, parameter.second); + } + std::unique_ptr agent_model; + if (lagent_model_list->Size() != 0) { + lagent_model_list->Back()->Location(&artifact_type, &location); + const auto config_path = JoinPath({location, kModelConfigPbTxt}); + if (!ReadTextProto(config_path, &model_config).IsOk()) { + model_config.Clear(); + } + } + 
RETURN_IF_ERROR(TritonRepoAgentModel::Create( + artifact_type, location, model_config, agent, agent_params, + &agent_model)); + RETURN_IF_ERROR(agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD)); + lagent_model_list->AddAgentModel(std::move(agent_model)); + } + *agent_model_list = std::move(lagent_model_list); + } + return Status::Success; +} + +int64_t +GetModifiedTime(const std::string& path) +{ + // If there is an error in any step the fall-back default + // modification time is 0. This means that in error cases 'path' + // will show as not modified. This is the safe fall-back to avoid + // assuming a model is constantly being modified. + bool path_is_dir; + Status status = IsDirectory(path, &path_is_dir); + if (!status.IsOk()) { + LOG_ERROR << "Failed to determine modification time for '" << path + << "': " << status.AsString(); + return 0; + } + + // If 'path' is a file return its mtime. Otherwise, using the modification + // time of the directory as baseline in case of file deletion + int64_t mtime = 0; + status = FileModificationTime(path, &mtime); + if (!status.IsOk()) { + LOG_ERROR << "Failed to determine modification time for '" << path + << "': " << status.AsString(); + return 0; + } + if (!path_is_dir) { + return mtime; + } + + // 'path' is a directory. Return the most recent mtime of the + // contents of the directory. + std::set contents; + status = GetDirectoryContents(path, &contents); + if (!status.IsOk()) { + LOG_ERROR << "Failed to determine modification time for '" << path + << "': " << status.AsString(); + return 0; + } + + for (const auto& child : contents) { + const auto full_path = JoinPath({path, child}); + mtime = std::max(mtime, GetModifiedTime(full_path)); + } + + return mtime; +} +// Return true if any file in the subdirectory root at 'path' has been +// modified more recently than 'last'. Return the most-recent modified +// time in 'last'. +bool +IsModified(const std::string& path, int64_t* last_ns) +{ + const int64_t repo_ns = GetModifiedTime(path); + bool modified = repo_ns > *last_ns; + *last_ns = repo_ns; + return modified; +} + +} // namespace + +struct ModelRepositoryManager::ModelInfo { + ModelInfo( + const int64_t mtime_nsec, const int64_t prev_mtime_ns, + const std::string& model_path) + : mtime_nsec_(mtime_nsec), prev_mtime_ns_(prev_mtime_ns), + explicitly_load_(true), model_path_(model_path), + is_config_provided_(false) + { + } + ModelInfo() + : mtime_nsec_(0), prev_mtime_ns_(0), explicitly_load_(true), + is_config_provided_(false) + { + } + int64_t mtime_nsec_; + int64_t prev_mtime_ns_; + bool explicitly_load_; + inference::ModelConfig model_config_; + std::string model_path_; + // Temporary location to hold agent model list before creating the model + // the ownership must transfer to ModelLifeCycle to ensure + // the agent model life cycle is handled properly. 
+ std::shared_ptr agent_model_list_; + bool is_config_provided_; +}; + +ModelRepositoryManager::ModelRepositoryManager( + const std::set& repository_paths, const bool autofill, + const bool polling_enabled, const bool model_control_enabled, + const double min_compute_capability, + std::unique_ptr life_cycle) + : repository_paths_(repository_paths), autofill_(autofill), + polling_enabled_(polling_enabled), + model_control_enabled_(model_control_enabled), + min_compute_capability_(min_compute_capability), + model_life_cycle_(std::move(life_cycle)) +{ +} + +ModelRepositoryManager::~ModelRepositoryManager() {} + +Status +ModelRepositoryManager::Create( + InferenceServer* server, const std::string& server_version, + const std::set& repository_paths, + const std::set& startup_models, const bool strict_model_config, + const bool polling_enabled, const bool model_control_enabled, + const ModelLifeCycleOptions& life_cycle_options, + std::unique_ptr* model_repository_manager) +{ + // The rest only matters if repository path is valid directory + for (const auto& path : repository_paths) { + bool path_is_dir; + RETURN_IF_ERROR(IsDirectory(path, &path_is_dir)); + if (!path_is_dir) { + return Status( + Status::Code::INVALID_ARG, + "repository path is not a valid directory"); + } + } + + if (polling_enabled && model_control_enabled) { + return Status( + Status::Code::INVALID_ARG, + "cannot enable both polling and explicit model control"); + } + + std::unique_ptr life_cycle; + RETURN_IF_ERROR( + ModelLifeCycle::Create(server, life_cycle_options, &life_cycle)); + + // Not setting the smart pointer directly to simplify clean up + std::unique_ptr local_manager( + new ModelRepositoryManager( + repository_paths, !strict_model_config, polling_enabled, + model_control_enabled, life_cycle_options.min_compute_capability_, + std::move(life_cycle))); + *model_repository_manager = std::move(local_manager); + + // Support loading all models on startup in explicit model control mode with + // special startup_model name "*". This does not imply support for pattern + // matching in model names. + bool load_all_models_on_startup = false; + if ((startup_models.find("*") != startup_models.end()) && + model_control_enabled) { + if (startup_models.size() > 1) { + return Status( + Status::Code::INVALID_ARG, + "Wildcard model name '*' must be the ONLY startup model " + "if specified at all."); + } + + load_all_models_on_startup = true; + } + + bool all_models_polled = true; + if (!model_control_enabled || load_all_models_on_startup) { + // only error happens before model load / unload will be return + // model loading / unloading error will be printed but ignored + RETURN_IF_ERROR( + (*model_repository_manager)->PollAndUpdateInternal(&all_models_polled)); + } else { + // Load each specified startup_model + std::unordered_map> + models; + for (const auto& model_name : startup_models) { + models[model_name]; + } + RETURN_IF_ERROR( + (*model_repository_manager) + ->LoadUnloadModels( + models, ActionType::LOAD, false, &all_models_polled)); + } + + + if (!all_models_polled) { + return Status(Status::Code::INTERNAL, "failed to load all models"); + } + // Some models may failed to be loaded after model manager is created, + // return proper error and let function caller decide whether to proceed. 
+  for (const auto& model : (*model_repository_manager)->infos_) {
+    const auto version_states =
+        (*model_repository_manager)
+            ->model_life_cycle_->VersionStates(model.first);
+    // Return a general error message; the details of each model's loading
+    // state are logged separately.
+    if (version_states.empty()) {
+      return Status(Status::Code::INTERNAL, "failed to load all models");
+    }
+    for (const auto& state : version_states) {
+      if (state.second.first != ModelReadyState::READY) {
+        return Status(Status::Code::INTERNAL, "failed to load all models");
+      }
+    }
+  }
+
+  return Status::Success;
+}
+
+Status
+ModelRepositoryManager::PollAndUpdate()
+{
+  if (!polling_enabled_) {
+    return Status(Status::Code::UNAVAILABLE, "polling is disabled");
+  }
+
+  bool all_models_polled;
+  return PollAndUpdateInternal(&all_models_polled);
+}
+
+Status
+ModelRepositoryManager::PollAndUpdateInternal(bool* all_models_polled)
+{
+  // Serialize all operations that change model state
+  std::lock_guard<std::mutex> lock(poll_mu_);
+
+  std::set<std::string> added, deleted, modified, unmodified;
+
+  // We don't modify 'infos_' in place to minimize how long we need to
+  // hold the lock, and also to prevent partial changes if an error occurs
+  // during processing.
+  ModelInfoMap new_infos;
+
+  // Each subdirectory of a repository path is a model directory from
+  // which we read the model configuration.
+  std::unordered_map<std::string, std::vector<const InferenceParameter*>>
+      subdirs;
+  RETURN_IF_ERROR(Poll(
+      subdirs, &added, &deleted, &modified, &unmodified, &new_infos,
+      all_models_polled));
+
+  // Anything in 'infos_' that is not in "added", "modified", or
+  // "unmodified" is deleted.
+  for (const auto& pr : infos_) {
+    if ((added.find(pr.first) == added.end()) &&
+        (modified.find(pr.first) == modified.end()) &&
+        (unmodified.find(pr.first) == unmodified.end())) {
+      deleted.insert(pr.first);
+    }
+  }
+
+  // Nothing to do if no models were added, deleted, or modified.
+  if (added.empty() && deleted.empty() && modified.empty()) {
+    return Status::Success;
+  }
+
+  infos_.swap(new_infos);
+
+  UpdateDependencyGraph(added, deleted, modified);
+
+  for (const auto& name : deleted) {
+    model_life_cycle_->AsyncUnload(name);
+  }
+
+  // model loading / unloading errors will be printed but ignored
+  LoadModelByDependency();
+
+  return Status::Success;
+}
+
+std::map<std::string, Status>
+ModelRepositoryManager::LoadModelByDependency()
+{
+  std::map<std::string, Status> res;
+  struct ModelState {
+    ModelState(DependencyNode* node) : node_(node), status_(Status::Success) {}
+    DependencyNode* node_;
+    Status status_;
+    std::promise<void> ready_;
+  };
+  NodeSet loaded_models;
+  auto set_pair = ModelsToLoadUnload(loaded_models);
+  // Loop until all models are loaded / unloaded
+  while ((!set_pair.first.empty()) || (!set_pair.second.empty())) {
+    loaded_models.clear();
+    // Unload invalid models first
+    for (auto& invalid_model : set_pair.second) {
+      model_life_cycle_->AsyncUnload(invalid_model->model_name_);
+      LOG_ERROR << invalid_model->status_.AsString();
+      invalid_model->loaded_versions_ = std::set<int64_t>();
+      loaded_models.emplace(invalid_model);
+    }
+    // Load valid models and wait for the load results
+    std::vector<std::unique_ptr<ModelState>> model_states;
+    for (auto& valid_model : set_pair.first) {
+      model_states.emplace_back(new ModelState(valid_model));
+      auto model_state = model_states.back().get();
+      const auto itr = infos_.find(valid_model->model_name_);
+      auto status = model_life_cycle_->AsyncLoad(
+          valid_model->model_name_, itr->second->model_path_,
+          valid_model->model_config_, itr->second->is_config_provided_,
+          itr->second->agent_model_list_, [model_state](Status load_status) {
+            model_state->status_ = load_status;
+            model_state->ready_.set_value();
+          });
+      if (!status.IsOk()) {
+        model_state->status_ = status;
+        model_state->ready_.set_value();
+        LOG_ERROR << "failed to load model '" << valid_model->model_name_
+                  << "': " << status.Message();
+      }
+      loaded_models.emplace(valid_model);
+    }
+    for (auto& model_state : model_states) {
+      model_state->ready_.get_future().wait();
+      res[model_state->node_->model_name_] = model_state->status_;
+      const auto version_state =
+          model_life_cycle_->VersionStates(model_state->node_->model_name_);
+      model_state->node_->loaded_versions_.clear();
+      for (const auto& vs : version_state) {
+        if (vs.second.first == ModelReadyState::READY) {
+          model_state->node_->loaded_versions_.emplace(vs.first);
+        }
+      }
+      // If the model failed to load, revert the timestamp so that the next
+      // load request will attempt to load the model again, for operation
+      // consistency.
+ if (!model_state->status_.IsOk()) { + auto& model_info = infos_.find(model_state->node_->model_name_)->second; + model_info->mtime_nsec_ = model_info->prev_mtime_ns_; + } + } + set_pair = ModelsToLoadUnload(loaded_models); + } + // Clear temporary stored agent model list after all loads are triggerred + for (auto& info : infos_) { + info.second->agent_model_list_.reset(); + } + return res; +} + +Status +ModelRepositoryManager::LoadUnloadModel( + const std::unordered_map< + std::string, std::vector>& models, + const ActionType type, const bool unload_dependents) +{ + if (!model_control_enabled_) { + return Status( + Status::Code::UNAVAILABLE, + "explicit model load / unload is not allowed if polling is enabled"); + } + + if (models.size() > 1) { + return Status( + Status::Code::UNSUPPORTED, + "explicit load / unload multiple models is not currently supported"); + } + + // Serialize all operations that change model state + std::lock_guard lock(poll_mu_); + + bool polled = true; + RETURN_IF_ERROR(LoadUnloadModels(models, type, unload_dependents, &polled)); + // Check if model is loaded / unloaded properly + const auto& model_name = models.begin()->first; + if (!polled) { + return Status( + Status::Code::INTERNAL, "failed to load '" + model_name + + "', failed to poll from model repository"); + } + + const auto version_states = model_life_cycle_->VersionStates(model_name); + if (type == ActionType::LOAD) { + if (version_states.empty()) { + return Status( + Status::Code::INTERNAL, + "failed to load '" + model_name + "', no version is available"); + } + auto it = infos_.find(model_name); + if (it == infos_.end()) { + return Status( + Status::Code::INTERNAL, + "failed to load '" + model_name + + "', failed to poll from model repository"); + } + } else { + std::string ready_version_str; + for (const auto& version_state : version_states) { + if (version_state.second.first == ModelReadyState::READY) { + ready_version_str += std::to_string(version_state.first); + ready_version_str += ","; + } + } + if (!ready_version_str.empty()) { + ready_version_str.pop_back(); + return Status( + Status::Code::INTERNAL, + "failed to unload '" + model_name + + "', versions that are still available: " + ready_version_str); + } + } + + return Status::Success; +} + +Status +ModelRepositoryManager::LoadUnloadModels( + const std::unordered_map< + std::string, std::vector>& models, + const ActionType type, const bool unload_dependents, + bool* all_models_polled) +{ + auto status = Status::Success; + *all_models_polled = true; + // Update ModelInfo related to file system accordingly + std::set added, deleted, modified, unmodified; + { + if (type == ActionType::UNLOAD) { + for (const auto& model : models) { + deleted.insert(model.first); + } + } + // ActionType::LOAD and in model control mode + else { + std::set checked_models; + auto current_models = models; + for (const auto& model : models) { + checked_models.emplace(model.first); + } + + ModelInfoMap new_infos; +#ifdef TRITON_ENABLE_ENSEMBLE + bool first_iteration = true; +#endif // TRITON_ENABLE_ENSEMBLE + while (!current_models.empty()) { + bool polled = true; + RETURN_IF_ERROR(Poll( + current_models, &added, &deleted, &modified, &unmodified, + &new_infos, &polled)); + *all_models_polled &= polled; + + // More models should be polled if the polled models are ensembles + std::unordered_map> + next_models; +#ifdef TRITON_ENABLE_ENSEMBLE + for (const auto& model : current_models) { + auto it = new_infos.find(model.first); + // Some models may be marked as deleted and 
not in 'new_infos' + if (it != new_infos.end()) { + it->second->explicitly_load_ = first_iteration; + const auto& config = it->second->model_config_; + if (config.has_ensemble_scheduling()) { + for (const auto& step : config.ensemble_scheduling().step()) { + bool need_poll = + checked_models.emplace(step.model_name()).second; + if (need_poll) { + next_models[step.model_name()]; + } + } + } + } + } + first_iteration = false; +#endif // TRITON_ENABLE_ENSEMBLE + current_models.swap(next_models); + } + + // Only update the infos when all validation is completed + for (const auto& model_name : added) { + auto nitr = new_infos.find(model_name); + infos_.emplace(model_name, std::move(nitr->second)); + } + for (const auto& model_name : modified) { + auto nitr = new_infos.find(model_name); + auto itr = infos_.find(model_name); + itr->second = std::move(nitr->second); + } + } + } + std::set deleted_dependents; + + // Update dependency graph and load + UpdateDependencyGraph( + added, deleted, modified, + unload_dependents ? &deleted_dependents : nullptr); + + // The models are in 'deleted' either when they are asked to be unloaded or + // they are not found / are duplicated across all model repositories. + // In all cases, should unload them and remove from 'infos_' explicitly. + for (const auto& name : (unload_dependents ? deleted_dependents : deleted)) { + infos_.erase(name); + model_life_cycle_->AsyncUnload(name); + } + + // load / unload the models affected, and check the load status of + // the requested models + const auto& load_status = LoadModelByDependency(); + if (status.IsOk() && (type == ActionType::LOAD)) { + std::string load_error_message = ""; + for (const auto& model : models) { + auto it = load_status.find(model.first); + // If 'model.first' not in load status, it means the (re-)load is not + // necessary because there is no change in the model's directory + if ((it != load_status.end()) && !it->second.IsOk()) { + load_error_message += + ("load failed for model '" + model.first + + "': " + it->second.Message() + "\n"); + } + } + if (!load_error_message.empty()) { + status = Status(Status::Code::INVALID_ARG, load_error_message); + } + } + + return status; +} + +Status +ModelRepositoryManager::UnloadAllModels() +{ + Status status; + for (const auto& name_info : infos_) { + Status unload_status = model_life_cycle_->AsyncUnload(name_info.first); + if (!unload_status.IsOk()) { + status = Status( + unload_status.ErrorCode(), + "Failed to gracefully unload models: " + unload_status.Message()); + } + } + return Status::Success; +} + +Status +ModelRepositoryManager::StopAllModels() +{ + return model_life_cycle_->StopAllModels(); +} + +const std::set> +ModelRepositoryManager::InflightStatus() +{ + return model_life_cycle_->InflightStatus(); +} + +const ModelStateMap +ModelRepositoryManager::LiveModelStates(bool strict_readiness) +{ + return model_life_cycle_->LiveModelStates(strict_readiness); +} + +const ModelStateMap +ModelRepositoryManager::ModelStates() +{ + return model_life_cycle_->ModelStates(); +} + +const VersionStateMap +ModelRepositoryManager::VersionStates(const std::string& model_name) +{ + return model_life_cycle_->VersionStates(model_name); +} + +Status +ModelRepositoryManager::ModelState( + const std::string& model_name, const int64_t model_version, + ModelReadyState* state) +{ + return model_life_cycle_->ModelState(model_name, model_version, state); +} + +Status +ModelRepositoryManager::RepositoryIndex( + const bool ready_only, std::vector* index) +{ + std::set seen_models; + 
std::set duplicate_models; + for (const auto& repository_path : repository_paths_) { + // For any mapped models in this repository, save the mapping + // from their subdirectory name to model name. + std::map models_in_repo; + for (const auto& mapping_it : model_mappings_) { + if (mapping_it.second.first == repository_path) { + models_in_repo.emplace( + BaseName(mapping_it.second.second), mapping_it.first); + } + } + std::set subdirs; + RETURN_IF_ERROR(GetDirectorySubdirs(repository_path, &subdirs)); + for (const auto& subdir : subdirs) { + auto model = subdir; + auto model_it = models_in_repo.find(subdir); + if (model_it != models_in_repo.end()) { + model = model_it->second; + } + + if (seen_models.find(model) != seen_models.end()) { + duplicate_models.insert(model); + } + + seen_models.insert(model); + } + } + + ModelStateMap states = ModelStates(); + + for (const auto& model : seen_models) { + // If the same model appears in multiple repostories then show it + // as unavailable since duplicate models are not allowed to load. + if (duplicate_models.find(model) != duplicate_models.end()) { + index->emplace_back( + model, -1 /* version */, ModelReadyState::UNAVAILABLE, + MODEL_READY_REASON_DUPLICATE); + continue; + } + + // If there is any version/state/reason associated with the model + // then include that in the index. + auto sitr = states.find(model); + if (sitr == states.end()) { + if (!ready_only) { + index->emplace_back(model); + } + } else { + for (const auto& pr : sitr->second) { + if (!ready_only || (pr.second.first == ModelReadyState::READY)) { + index->emplace_back( + model, pr.first, pr.second.first, pr.second.second); + } + } + } + } + + return Status::Success; +} + +Status +ModelRepositoryManager::GetModel( + const std::string& model_name, const int64_t model_version, + std::shared_ptr* model) +{ + Status status = model_life_cycle_->GetModel(model_name, model_version, model); + if (!status.IsOk()) { + model->reset(); + status = Status( + status.ErrorCode(), "Request for unknown model: " + status.Message()); + } + return status; +} + +Status +ModelRepositoryManager::Poll( + const std::unordered_map< + std::string, std::vector>& models, + std::set* added, std::set* deleted, + std::set* modified, std::set* unmodified, + ModelInfoMap* updated_infos, bool* all_models_polled) +{ + *all_models_polled = true; + // empty path is the special case to indicate the model should be loaded + // from override file content in 'models'. + std::map model_to_path; + + // If no model is specified, poll all models in all model repositories. + // Otherwise, only poll the specified models + if (models.empty()) { + std::set duplicated_models; + for (const auto& repository_path : repository_paths_) { + std::set subdirs; + Status status = GetDirectorySubdirs(repository_path, &subdirs); + if (!status.IsOk()) { + LOG_ERROR << "failed to poll model repository '" << repository_path + << "': " << status.Message(); + *all_models_polled = false; + } else { + for (const auto& subdir : subdirs) { + if (!model_to_path + .emplace(subdir, JoinPath({repository_path, subdir})) + .second) { + duplicated_models.insert(subdir); + *all_models_polled = false; + } + } + } + } + // If the model is not unique, mark as deleted to unload it + for (const auto& model : duplicated_models) { + model_to_path.erase(model); + deleted->insert(model); + LOG_ERROR << "failed to poll model '" << model + << "': not unique across all model repositories"; + } + } + // If models are specified, this is explicit model control mode. 
+ else { + for (const auto& model : models) { + // Skip repository polling if override model files + if (ModelDirectoryOverride(model.second)) { + model_to_path.emplace(model.first, ""); + continue; + } + // Check model mapping first to see if matching model to load. + bool exists = false; + auto model_it = model_mappings_.find(model.first); + if (model_it != model_mappings_.end()) { + bool exists_in_this_repo = false; + auto full_path = model_it->second.second; + Status status = FileExists(full_path, &exists_in_this_repo); + if (!status.IsOk()) { + LOG_ERROR << "failed to poll mapped path '" << full_path + << "' for model '" << model.first + << "': " << status.Message(); + *all_models_polled = false; + } + if (exists_in_this_repo) { + model_to_path.emplace(model.first, model_it->second.second); + exists = true; + } else { + LOG_ERROR << "mapped path '" << full_path + << "' does not exist for model '" << model.first << "'"; + exists = false; + } + } else { + for (const auto repository_path : repository_paths_) { + bool exists_in_this_repo = false; + const auto full_path = JoinPath({repository_path, model.first}); + Status status = FileExists(full_path, &exists_in_this_repo); + if (!status.IsOk()) { + LOG_ERROR << "failed to poll model repository '" << repository_path + << "' for model '" << model.first + << "': " << status.Message(); + *all_models_polled = false; + } else if (exists_in_this_repo) { + // Check to make sure this directory is not mapped. + // If mapped, continue to next repository path. + bool mapped = false; + for (auto const& mapping : model_mappings_) { + if (mapping.second.second == full_path) { + mapped = true; + break; + } + } + if (mapped) { + continue; + } + + auto res = model_to_path.emplace( + model.first, JoinPath({repository_path, model.first})); + if (res.second) { + exists = true; + } else { + exists = false; + model_to_path.erase(res.first); + LOG_ERROR << "failed to poll model '" << model.first + << "': not unique across all model repositories"; + break; + } + } + } + } + // For an explicitly specified model that doesn't exist, we don't mark it + // as deleted, we simply mark that we couldn't poll all models. + if (!exists) { + *all_models_polled = false; + } + } + } + + // Poll each of the models. If error happens during polling the model, + // its state will fallback to the state before the polling. + for (const auto& pair : model_to_path) { + std::unique_ptr model_info; + const auto& mit = models.find(pair.first); + static std::vector empty_params; + auto status = InitializeModelInfo( + pair.first, pair.second, + ((mit == models.end()) ? 
empty_params : mit->second), &model_info); + + const auto& iitr = infos_.find(pair.first); + const bool invalid_add = (!status.IsOk()) && (iitr == infos_.end()); + if (!invalid_add) { + const auto& ret = updated_infos->emplace(pair.first, nullptr); + if (!ret.second) { + return Status( + Status::Code::ALREADY_EXISTS, + "unexpected model info for model '" + pair.first + "'"); + } + + // Classify load state and set updated info + if (model_info == nullptr) { + ret.first->second.reset(new ModelInfo(*iitr->second)); + unmodified->insert(pair.first); + } else { + ret.first->second = std::move(model_info); + if (iitr != infos_.end()) { + modified->insert(pair.first); + } else { + added->insert(pair.first); + } + } + } + + if (!status.IsOk()) { + LOG_ERROR << "Poll failed for model directory '" << pair.first + << "': " << status.Message(); + *all_models_polled = false; + } + } + + return Status::Success; +} + +bool +ModelRepositoryManager::ModelDirectoryOverride( + const std::vector& model_params) +{ + for (const auto& param : model_params) { + if (param->Name().rfind(file_prefix, 0) == 0) { + // param name starts with prefix if user provides override file + return true; + } + } + return false; +} + +Status +ModelRepositoryManager::InitializeModelInfo( + const std::string& name, const std::string& path, + const std::vector& params, + std::unique_ptr* info) +{ + std::unique_ptr linfo(new ModelInfo()); + linfo->model_path_ = path; + + bool unmodified = false; + + const auto iitr = infos_.find(name); + // Set 'prev_mtime_ns_' if there is existing ModelInfo + if (iitr != infos_.end()) { + linfo->prev_mtime_ns_ = iitr->second->mtime_nsec_; + } else { + linfo->prev_mtime_ns_ = 0; + } + + // Set 'mtime_nsec_' and override 'model_path_' if current path is empty + // (file override is specified) + if (linfo->model_path_.empty()) { + // Need to localize the override files, use repo agent to manage + // the lifecycle of the localized files + std::shared_ptr localize_agent(new LocalizeRepoAgent()); + std::unique_ptr localize_agent_model; + RETURN_IF_ERROR(TritonRepoAgentModel::Create( + TRITONREPOAGENT_ARTIFACT_FILESYSTEM, "", inference::ModelConfig(), + localize_agent, {}, &localize_agent_model)); + + // Set agent model state so the repo agent can access the encoded files + // Using const_cast here but we are safe as the RepoAgent will not + // modify the state + localize_agent_model->SetState( + const_cast(reinterpret_cast(¶ms))); + RETURN_IF_ERROR( + localize_agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD)); + + const char* location; + TRITONREPOAGENT_ArtifactType type; + RETURN_IF_ERROR(localize_agent_model->Location(&type, &location)); + + // For file override, set 'mtime_nsec_' to minimum value so that + // the next load without override will trigger re-load to undo + // the override while the local files may still be unchanged. 
+ linfo->mtime_nsec_ = 0; + linfo->model_path_ = location; + linfo->agent_model_list_.reset(new TritonRepoAgentModelList()); + linfo->agent_model_list_->AddAgentModel(std::move(localize_agent_model)); + } else { + if (iitr == infos_.end()) { + linfo->mtime_nsec_ = GetModifiedTime(std::string(linfo->model_path_)); + } else { + // Check the current timestamps to determine if model actually has been + // modified + linfo->mtime_nsec_ = linfo->prev_mtime_ns_; + unmodified = + !IsModified(std::string(linfo->model_path_), &linfo->mtime_nsec_); + } + } + + // Set 'model_config_' + bool parsed_config = false; + // Check if there is config override + for (const auto& override_parameter : params) { + if ((override_parameter->Name() == "config") && + (override_parameter->Type() == TRITONSERVER_PARAMETER_STRING)) { + // When override happens, set 'mtime_nsec_' to minimum value so that + // the next load without override will trigger re-load to undo + // the override while the local files may still be unchanged. + linfo->mtime_nsec_ = 0; + unmodified = false; + + const std::string& override_config = override_parameter->ValueString(); + auto err = JsonToModelConfig( + override_config, 1 /* config_version */, &linfo->model_config_); + if (!err.IsOk()) { + return Status( + Status::Code::INVALID_ARG, + "Invalid config override: " + std::string(err.Message())); + } + parsed_config = true; + break; + } else if (override_parameter->Name().rfind(file_prefix, 0) != 0) { + return Status( + Status::Code::INVALID_ARG, + "Unrecognized load parameter '" + override_parameter->Name() + + "' with type '" + + TRITONSERVER_ParameterTypeString(override_parameter->Type()) + + "'"); + } + } + + // Polling model is considered unmodified by this point and can be returned + // with info == nullptr + if (unmodified) { + return Status::Success; + } + + // Create the associated repo agent models when a model is to be loaded, + // this must be done before normalizing model config as agents might + // redirect to use the model config at a different location + if (!parsed_config) { + const auto config_path = JoinPath({linfo->model_path_, kModelConfigPbTxt}); + bool model_config_exists = false; + RETURN_IF_ERROR(FileExists(config_path, &model_config_exists)); + // model config can be missing if auto fill is set + if (autofill_ && !model_config_exists) { + linfo->model_config_.Clear(); + } else { + RETURN_IF_ERROR(ReadTextProto(config_path, &linfo->model_config_)); + parsed_config = true; + } + } + if (parsed_config) { + RETURN_IF_ERROR(CreateAgentModelListWithLoadAction( + linfo->model_config_, linfo->model_path_, &linfo->agent_model_list_)); + if (linfo->agent_model_list_ != nullptr) { + // Get the latest repository path + const char* location; + TRITONREPOAGENT_ArtifactType artifact_type; + RETURN_IF_ERROR(linfo->agent_model_list_->Back()->Location( + &artifact_type, &location)); + auto latest_path = std::string(location); + linfo->model_path_ = latest_path; + } + } + linfo->is_config_provided_ = parsed_config; + + // Try to automatically generate missing parts of the model + // configuration (autofill) that don't require model detail + RETURN_IF_ERROR(GetNormalizedModelConfig( + name, linfo->model_path_, min_compute_capability_, + &linfo->model_config_)); + + // Note that the model inputs and outputs are not validated until + // the model model is intialized as they may not be auto-completed + // until model is intialized. 
+ RETURN_IF_ERROR( + ValidateModelConfig(linfo->model_config_, min_compute_capability_)); + if (!autofill_) { + RETURN_IF_ERROR(ValidateModelIOConfig(linfo->model_config_)); + } + + // If the model is mapped, update its config name based on the + // mapping. + if (model_mappings_.find(name) != model_mappings_.end()) { + linfo->model_config_.set_name(name); + } else { + // If there is no model mapping, make sure the name of the model + // matches the name of the directory. This is a somewhat arbitrary + // requirement but seems like good practice to require it of the user. + // It also acts as a check to make sure we don't have two different + // models with the same name. + if (linfo->model_config_.name() != name) { + return Status( + Status::Code::INVALID_ARG, + "unexpected directory name '" + name + "' for model '" + + linfo->model_config_.name() + + "', directory name must equal model name"); + } + } + + *info = std::move(linfo); + return Status::Success; +} + +Status +ModelRepositoryManager::UpdateDependencyGraph( + const std::set& added, const std::set& deleted, + const std::set& modified, + std::set* deleted_dependents) +{ + // update dependency graph, if the state of a node is changed, all its + // downstreams will be affected + + // deleted, drop from dependency_graph, add to missing_nodes if downstreams is + // not empty affected_nodes are all ensembles as only ensembles are depending + // on other models + std::set affected_nodes; + std::set updated_nodes; + std::set current_deleted = deleted; + while (!current_deleted.empty()) { + std::set next_deleted; + for (const auto& model_name : current_deleted) { + auto it = dependency_graph_.find(model_name); + if (it != dependency_graph_.end()) { + // remove this node from its upstreams + for (auto& upstream : it->second->upstreams_) { + upstream.first->downstreams_.erase(it->second.get()); + // Check if the upstream should be removed as well + if ((deleted_dependents != nullptr) && + (upstream.first->downstreams_.empty()) && + (!upstream.first->explicitly_load_)) { + next_deleted.emplace(upstream.first->model_name_); + } + } + it->second->upstreams_.clear(); + + if (!it->second->downstreams_.empty()) { + UncheckDownstream(&it->second->downstreams_, &affected_nodes); + // mark this node as missing upstream in its downstreams + for (auto& downstream : it->second->downstreams_) { + downstream->missing_upstreams_.emplace(it->second.get()); + } + missing_nodes_.emplace( + std::make_pair(model_name, std::move(it->second))); + } + + // Make sure deleted node will not be in affected nodes + affected_nodes.erase(it->second.get()); + dependency_graph_.erase(it); + } + if (deleted_dependents != nullptr) { + deleted_dependents->emplace(model_name); + } + } + current_deleted.swap(next_deleted); + } + + // modified, invalidate (uncheck) all downstreams + for (const auto& model_name : modified) { + auto it = dependency_graph_.find(model_name); + if (it != dependency_graph_.end()) { + UncheckDownstream(&it->second->downstreams_, &affected_nodes); + ModelInfo* info = nullptr; + GetModelInfo(model_name, &info); + it->second->model_config_ = info->model_config_; + it->second->explicitly_load_ = info->explicitly_load_; + // remove this node from its upstream node + for (auto& upstream : it->second->upstreams_) { + upstream.first->downstreams_.erase(it->second.get()); + } + it->second->upstreams_.clear(); + it->second->checked_ = false; + it->second->status_ = Status::Success; + updated_nodes.emplace(it->second.get()); + } + } + + // added, add to 
dependency_graph, if in missing_node, invalidate (uncheck) + // and associate all downstreams, remove from missing_node + for (const auto& model_name : added) { + std::unique_ptr added_node; + auto it = missing_nodes_.find(model_name); + if (it != missing_nodes_.end()) { + UncheckDownstream(&it->second->downstreams_, &affected_nodes); + // remove this node from missing upstream node in its downstream nodes + for (auto& downstream : it->second->downstreams_) { + downstream->missing_upstreams_.erase(it->second.get()); + } + + it->second->checked_ = false; + added_node = std::move(it->second); + missing_nodes_.erase(it); + } else { + // Right now, nothing is going to be filled until validation + added_node.reset(new DependencyNode(model_name)); + } + ModelInfo* info = nullptr; + GetModelInfo(model_name, &info); + added_node->model_config_ = info->model_config_; + added_node->explicitly_load_ = info->explicitly_load_; + updated_nodes.emplace(added_node.get()); + dependency_graph_.emplace( + std::make_pair(model_name, std::move(added_node))); + } + + auto& affected_ensembles = affected_nodes; + for (auto& updated_node : updated_nodes) { + bool is_ensemble = ConnectDependencyGraph(updated_node); + if (is_ensemble) { + affected_ensembles.emplace(updated_node); + } + } + +#ifdef TRITON_ENABLE_ENSEMBLE + // After the dependency graph is updated, check ensemble dependencies + for (auto& ensemble : affected_ensembles) { + if (ensemble->status_.IsOk()) { + if (!ensemble->missing_upstreams_.empty()) { + std::string name_list; + for (auto it = ensemble->missing_upstreams_.begin(); + it != ensemble->missing_upstreams_.end(); it++) { + if (it != ensemble->missing_upstreams_.begin()) { + name_list += ", "; + } + name_list += (*it)->model_name_; + } + ensemble->status_ = Status( + Status::Code::INVALID_ARG, + "ensemble " + ensemble->model_name_ + + " contains models that are not available: " + name_list); + } else { + ensemble->status_ = CircularcyCheck(ensemble, ensemble); + } + } + } +#endif // TRITON_ENABLE_ENSEMBLE + return Status::Success; +} + +Status +ModelRepositoryManager::RegisterModelRepository( + const std::string& repository, + const std::unordered_map& model_mapping) +{ + if (!model_control_enabled_) { + return Status( + Status::Code::UNSUPPORTED, + "repository registration is not allowed if model control mode is not " + "EXPLICIT"); + } + bool is_directory = false; + auto status = IsDirectory(repository, &is_directory); + if (!status.IsOk() || !is_directory) { + return Status( + Status::Code::INVALID_ARG, (std::string("failed to register '") + + repository + "', repository not found") + .c_str()); + } + + { + // Serialize all operations that change model state + std::lock_guard lock(poll_mu_); + + // Check repository and mapped models do not yet exist. 
+ if (repository_paths_.find(repository) != repository_paths_.end()) { + return Status( + Status::Code::ALREADY_EXISTS, + "model repository '" + repository + "' has already been registered"); + } + + for (const auto& mapping : model_mapping) { + if (model_mappings_.find(mapping.first) != model_mappings_.end()) { + return Status( + Status::Code::ALREADY_EXISTS, + (std::string("failed to register '") + mapping.first + + "', there is a conflicting mapping for '" + + std::string(mapping.first) + "'") + .c_str()); + } + } + + repository_paths_.emplace(repository); + for (const auto& mapping : model_mapping) { + model_mappings_.emplace( + mapping.first, + std::make_pair(repository, JoinPath({repository, mapping.second}))); + } + } + + LOG_INFO << "Model repository registered: " << repository; + return Status::Success; +} + +Status +ModelRepositoryManager::UnregisterModelRepository(const std::string& repository) +{ + if (!model_control_enabled_) { + return Status( + Status::Code::UNSUPPORTED, + "repository unregistration is not allowed if model control mode is not " + "EXPLICIT"); + } + { + std::lock_guard lock(poll_mu_); + if (repository_paths_.erase(repository) != 1) { + return Status( + Status::Code::INVALID_ARG, + "failed to unregister '" + repository + "', repository not found"); + } + + std::set models_to_delete; + for (auto const& mapping : model_mappings_) { + if (mapping.second.first == repository) { + models_to_delete.insert(mapping.first); + } + } + for (auto const& model : models_to_delete) { + model_mappings_.erase(model); + } + } + + LOG_INFO << "Model repository unregistered: " << repository; + return Status::Success; +} + +Status +ModelRepositoryManager::CircularcyCheck( + DependencyNode* current_node, const DependencyNode* start_node) +{ + for (auto& downstream : current_node->downstreams_) { + if (downstream->model_name_ == start_node->model_name_) { + return Status( + Status::Code::INVALID_ARG, + "circular dependency between ensembles: " + start_node->model_name_ + + " -> ... 
-> " + current_node->model_name_ + " -> " + + start_node->model_name_); + } else { + const auto status = CircularcyCheck(downstream, start_node); + if (!status.IsOk() && current_node->status_.IsOk()) { + current_node->status_ = status; + return status; + } + } + } + return Status::Success; +} + +void +ModelRepositoryManager::UncheckDownstream( + NodeSet* downstreams, NodeSet* updated_nodes) +{ + // Mark downstream nodes as unchecked recursively + for (auto& node : *downstreams) { + if (node->checked_) { + node->checked_ = false; + node->status_ = Status::Success; + UncheckDownstream(&node->downstreams_, updated_nodes); + updated_nodes->emplace(node); + } + } +} + +bool +ModelRepositoryManager::ConnectDependencyGraph(DependencyNode* updated_node) +{ + // Check the node's model config to determine if it depends on other models + // and if those models are present + updated_node->upstreams_.clear(); + updated_node->missing_upstreams_.clear(); + if (updated_node->model_config_.has_ensemble_scheduling()) { + for (const auto& step : + updated_node->model_config_.ensemble_scheduling().step()) { + DependencyNode* upstream_node = nullptr; + const auto& model_name = step.model_name(); + auto dit = dependency_graph_.find(model_name); + if (dit == dependency_graph_.end()) { + auto mit = missing_nodes_.find(model_name); + if (mit == missing_nodes_.end()) { + std::unique_ptr node(new DependencyNode(model_name)); + updated_node->missing_upstreams_.emplace(node.get()); + mit = missing_nodes_.emplace(model_name, std::move(node)).first; + } + // Add the node to missing node's downstream so that when the missing + // node is added, the downstreams can be found easily. + mit->second->downstreams_.emplace(updated_node); + upstream_node = mit->second.get(); + } else { + dit->second->downstreams_.emplace(updated_node); + upstream_node = dit->second.get(); + } + auto res = updated_node->upstreams_.emplace( + upstream_node, std::set({step.model_version()})); + // If map insertion doesn't happen, the same model is required in + // different step, insert the version to existing required version set. 
+ if (!res.second) { + res.first->second.insert(step.model_version()); + } + } + return true; + } + return false; +} + +Status +ModelRepositoryManager::GetModelInfo( + const std::string& name, ModelInfo** model_info) +{ + const auto itr = infos_.find(name); + if (itr == infos_.end()) { + return Status( + Status::Code::NOT_FOUND, "no configuration for model '" + name + "'"); + } + + *model_info = itr->second.get(); + return Status::Success; +} + +std::pair +ModelRepositoryManager::ModelsToLoadUnload(const NodeSet& loaded_models) +{ + // + std::pair res; + // first call to this function + if (loaded_models.empty()) { + for (auto& pair : dependency_graph_) { + auto node = pair.second.get(); + // only care about nodes that are affected by the update + if (!node->checked_) { + if (CheckNode(node)) { + if (node->status_.IsOk()) { + res.first.emplace(node); + } else { + res.second.emplace(node); + } + } + } + } + } else { + for (const auto& model : loaded_models) { + for (auto node : model->downstreams_) { + // only care about nodes that are affected by the update + if (!node->checked_) { + if (CheckNode(node)) { + if (node->status_.IsOk()) { + res.first.emplace(node); + } else { + res.second.emplace(node); + } + } + } + } + } + } + for (auto& node : res.first) { + node->checked_ = true; + } + for (auto& node : res.second) { + node->checked_ = true; + } + return res; +} + +bool +ModelRepositoryManager::CheckNode(DependencyNode* node) +{ + bool node_ready = true; + // if the node is in invalid status, mark as ready as we know + // it should not be loaded + if (node->status_.IsOk()) { + for (auto& upstream : node->upstreams_) { + if (!upstream.first->checked_) { + node_ready = false; + break; + } + if (!upstream.first->status_.IsOk()) { + node->status_ = Status( + Status::Code::INVALID_ARG, + "ensemble '" + node->model_name_ + "' depends on '" + + upstream.first->model_name_ + "' which is not valid"); + } else if (upstream.first->loaded_versions_.empty()) { + node->status_ = Status( + Status::Code::INVALID_ARG, + "ensemble '" + node->model_name_ + "' depends on '" + + upstream.first->model_name_ + "' which has no loaded version"); + } else { + for (const auto& required_version : upstream.second) { + if (required_version == -1) { + continue; + } + + auto it = upstream.first->loaded_versions_.find(required_version); + if (it == upstream.first->loaded_versions_.end()) { + node->status_ = Status( + Status::Code::INVALID_ARG, + "ensemble '" + node->model_name_ + "' depends on '" + + upstream.first->model_name_ + "' whose required version " + + std::to_string(required_version) + " is not loaded"); + } + } + } + if (!node->status_.IsOk()) { + break; + } + } +#ifdef TRITON_ENABLE_ENSEMBLE + // Validate ensemble config if the node is ready. By this point, the + // depending models are loaded and their configs are completed + if (node_ready && node->status_.IsOk()) { + node->status_ = ValidateEnsembleConfig(this, node); + } +#endif // TRITON_ENABLE_ENSEMBLE + } + return node_ready; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/model_repository_manager.h b/3rdparty/core-r22.12/src/model_repository_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..bd06723801da4dd5e01bd800f1545fecb1a26e25 --- /dev/null +++ b/3rdparty/core-r22.12/src/model_repository_manager.h @@ -0,0 +1,345 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +#pragma once + +#include +#include +#include +#include +#include "infer_parameter.h" +#include "model_config.pb.h" +#include "model_lifecycle.h" +#include "status.h" +#include "triton/common/model_config.h" + +namespace triton { namespace core { + +class InferenceServer; +class Model; + +// [FIXME] should have separated load / unload functions for clarity +enum ActionType { NO_ACTION, LOAD, UNLOAD }; + +/// Predefined reason strings +#define MODEL_READY_REASON_DUPLICATE "model appears in two or more repositories" + +/// An object to manage the model repository active in the server. +class ModelRepositoryManager { + public: + // Index information for a model. + struct ModelIndex { + ModelIndex(const std::string& n) + : name_only_(true), name_(n), version_(-1), + state_(ModelReadyState::UNKNOWN) + { + } + ModelIndex( + const std::string& n, const int64_t v, const ModelReadyState s, + const std::string& r) + : name_only_(false), name_(n), version_(v), state_(s), reason_(r) + { + } + const bool name_only_; + const std::string name_; + const int64_t version_; + const ModelReadyState state_; + const std::string reason_; + }; + + /// A basic unit in dependency graph that records the models seen by the model + /// repository manager. + struct DependencyNode { + DependencyNode(const std::string& model_name) + : model_name_(model_name), status_(Status::Success), checked_(false) + { + } + + std::string model_name_; + Status status_; + bool checked_; + bool explicitly_load_; + inference::ModelConfig model_config_; + std::set loaded_versions_; + std::set missing_upstreams_; + std::unordered_map> upstreams_; + std::set downstreams_; + }; + + ~ModelRepositoryManager(); + + /// Create a manager for a repository. + /// \param server The pointer to the inference server. + /// \param server_version The version of the inference server. + /// \param repository_paths A set of file-system paths of the repositories. + /// \param startup_models A set of models to be loaded at startup + /// if model control is enabled. 
+ /// \param strict_model_config If false attempt to autofill missing required + /// information in each model configuration. + /// \param polling_enabled If true, then PollAndUpdate() is allowed. + /// Otherwise, it is not allowed. + /// \param model_control_enabled If true, then LoadUnloadModel() is allowed + /// and the models in the model repository will not be loaded at startup. + /// Otherwise, LoadUnloadModel() is not allowed and the models will be loaded. + /// Cannot be set to true if polling_enabled is true. + /// \param life_cycle_options The options to configure ModelLifeCycle. + /// \param model_repository_manager Return the model repository manager. + /// \return The error status. + static Status Create( + InferenceServer* server, const std::string& server_version, + const std::set& repository_paths, + const std::set& startup_models, + const bool strict_model_config, const bool polling_enabled, + const bool model_control_enabled, + const ModelLifeCycleOptions& life_cycle_options, + std::unique_ptr* model_repository_manager); + + /// Poll the model repository to determine the new set of models and + /// compare with the current set. And serve the new set of models based + /// on their version policy. + Status PollAndUpdate(); + + /// Load or unload a specified model. + /// \param models The models and the parameters to be loaded or unloaded + /// \param type The type action to be performed. If the action is LOAD and + /// the model has been loaded, the model will be re-loaded. + /// \return error status. Return "NOT_FOUND" if it tries to load + /// a non-existing model or if it tries to unload a model that hasn't been + /// loaded. + Status LoadUnloadModel( + const std::unordered_map< + std::string, std::vector>& models, + const ActionType type, const bool unload_dependents); + + /// Unload all models. This function should be called before shutting down + /// the model repository manager. + /// \return error status. + Status UnloadAllModels(); + + /// Instruct all models to stop accepting new inference requests. However, + /// the models are still capable of processing inference requests + /// if the model considers them as part of the in-flight inference. + /// \return error status. + Status StopAllModels(); + + /// \return the number of in-flight inferences for the all versions of all + /// models. The set element will be a tuple of . Note that a model version will not be included + /// if it doesn't have in-flight inferences. + const std::set> InflightStatus(); + + /// \param strict_readiness If true, only models that have at least one + /// ready version will be considered as live. Otherwise, the models that + /// have loading / unloading versions will also be live. + /// \return the state of all versions of all live models. + const ModelStateMap LiveModelStates(bool strict_readiness = false); + + /// \return the state of all versions of all models that have every + /// been (attempted) loaded over the lifetime of the server. + const ModelStateMap ModelStates(); + + /// \return the states of all versions of a specific model. + const VersionStateMap VersionStates(const std::string& model_name); + + /// \return the ready-state of a specific model version. + Status ModelState( + const std::string& model_name, const int64_t model_version, + ModelReadyState* state); + + /// Get the index of all models in all repositories. + /// \param ready_only If true return only index of models that are ready. + /// \param index Returns the index. + /// \return error status. 
+ Status RepositoryIndex(const bool ready_only, std::vector* index); + + /// Obtain the specified model. + /// \param model_name The name of the model. + /// \param model_version The version of the model. + /// \param model Return the model object. + /// \return error status. + Status GetModel( + const std::string& model_name, const int64_t model_version, + std::shared_ptr* model); + + // Register model repository path. + /// \param repository Path to model repository. + /// \param model_mapping Mapping with (overridden) model name as key, subdir + /// name as value. + /// \return error status + Status RegisterModelRepository( + const std::string& repository, + const std::unordered_map& model_mapping); + + // Unregister model repository path. + /// \param repository Path to model repository. + /// \return error status + Status UnregisterModelRepository(const std::string& repository); + + private: + struct ModelInfo; + + // Map from model name to information about the model. + using ModelInfoMap = + std::unordered_map>; + + // Set of DependencyNode + using NodeSet = std::set; + + ModelRepositoryManager( + const std::set& repository_paths, const bool autofill, + const bool polling_enabled, const bool model_control_enabled, + const double min_compute_capability, + std::unique_ptr life_cycle); + + /// The internal function that are called in Create() and PollAndUpdate(). + Status PollAndUpdateInternal(bool* all_models_polled); + + /// The internal function that load or unload a set of models. + Status LoadUnloadModels( + const std::unordered_map< + std::string, std::vector>& models, + const ActionType type, const bool unload_dependents, + bool* all_models_polled); + + /// Poll the requested models in the model repository and + /// compare with the current set. Return the additions, deletions, + /// and modifications that have occurred. This function will not updated + /// the current model info, it is caller's responsibility to do so. + /// \param models The map from models to be polled to their associated + /// parameters. + /// \param added The names of the models added to the repository. + /// \param deleted The names of the models removed from the repository. + /// \param modified The names of the models remaining in the + /// repository that have been changed. + /// \param unmodified The names of the models remaining in the + /// repository that have not changed. + /// \param updated_infos The model infos retrieved from the poll. + /// \param all_models_polled Return true if all models are polled and + /// their model configuration are validated successfully. Instead of aborting + /// the polling, the models that fail will be ignored and their model infos + /// will stay in the previous state. + /// \return The error status. + Status Poll( + const std::unordered_map< + std::string, std::vector>& models, + std::set* added, std::set* deleted, + std::set* modified, std::set* unmodified, + ModelInfoMap* updated_infos, bool* all_models_polled); + + /// Helper function for Poll() to initialize ModelInfo for the model. + /// \param name The name of the model. + /// \param path The model path. Empty path means the model is provided via + /// 'params' + /// \param params The model parameters provided for polling model. + /// \param info Return the updated ModelInfo. 'nullptr' will be returned if + /// existing ModelInfo for the model should be reused. + /// \return The error status. 
+  Status InitializeModelInfo(
+      const std::string& name, const std::string& path,
+      const std::vector<const InferenceParameter*>& params,
+      std::unique_ptr<ModelInfo>* info);
+
+  /// Load models based on the dependency graph. The function iteratively
+  /// loads models whose dependencies have all been loaded, and unloads
+  /// models if their dependencies are no longer satisfied.
+  /// \return The status of the model loads.
+  std::map<std::string, Status> LoadModelByDependency();
+
+  /// Helper function to update the dependency graph based on the poll result.
+  /// \param added The names of the models added to the repository.
+  /// \param deleted The names of the models removed from the repository.
+  /// \param modified The names of the models remaining in the
+  /// repository that have been changed.
+  /// \param deleted_dependents The names of dependent models to be removed
+  /// from the repository.
+  /// \return The error status.
+  Status UpdateDependencyGraph(
+      const std::set<std::string>& added, const std::set<std::string>& deleted,
+      const std::set<std::string>& modified,
+      std::set<std::string>* deleted_dependents = nullptr);
+
+  /// Helper function to uncheck the nodes because a model that they depend
+  /// on has changed. The unchecked nodes will be validated again.
+  /// The function is called recursively to uncheck all downstreams.
+  /// \param downstreams The nodes to be unchecked.
+  /// \param updated_nodes Return the nodes that have been unchecked.
+  void UncheckDownstream(NodeSet* downstreams, NodeSet* updated_nodes);
+
+  /// Helper function to construct the edges between nodes in the dependency
+  /// graph.
+  /// \param updated_node The node that is newly added or modified.
+  /// \return True if the node represents an ensemble model. False otherwise.
+  bool ConnectDependencyGraph(DependencyNode* updated_node);
+
+  /// Get the model info for a named model.
+  /// \param name The model name.
+  /// \param model_info Returns the model information.
+  /// \return OK if found, NOT_FOUND otherwise.
+  Status GetModelInfo(const std::string& name, ModelInfo** model_info);
+
+  /// Get the models to be loaded / unloaded based on the models loaded in the
+  /// previous iteration.
+  /// \param loaded_models The models loaded / unloaded in the previous
+  /// iteration. Unloaded models will be represented as models with no loaded
+  /// versions.
+  /// \return A pair of node sets containing the models to be loaded and the
+  /// models to be unloaded for the next iteration.
+  std::pair<NodeSet, NodeSet> ModelsToLoadUnload(const NodeSet& loaded_models);
+
+  /// Check if the node is ready for the next iteration. A node is ready if the
+  /// node is invalid (containing an invalid model config or having
+  /// dependencies that failed to load) or if all of its dependencies are
+  /// satisfied.
+  /// \param node The node to be checked.
+  /// \return True if the node is ready. False otherwise.
+ bool CheckNode(DependencyNode* node); + + Status CircularcyCheck( + DependencyNode* current_node, const DependencyNode* start_node); + + bool ModelDirectoryOverride( + const std::vector& model_params); + + std::set repository_paths_; + const bool autofill_; + const bool polling_enabled_; + const bool model_control_enabled_; + const double min_compute_capability_; + + std::mutex poll_mu_; + ModelInfoMap infos_; + + std::unordered_map> + dependency_graph_; + std::unordered_map> + missing_nodes_; + + // Mappings from (overridden) model names to a pair of their repository and + // absolute path + std::unordered_map> + model_mappings_; + + std::unique_ptr model_life_cycle_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/numa_utils.cc b/3rdparty/core-r22.12/src/numa_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..03f7af278990dc17c8a6746ad693ebfe46bf2b48 --- /dev/null +++ b/3rdparty/core-r22.12/src/numa_utils.cc @@ -0,0 +1,237 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
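+// Hedged usage note (not part of the upstream logic below): the
+// HostPolicyCmdlineConfig maps consumed in this file are typically populated
+// from tritonserver's --host-policy command-line flag, for example something
+// like
+//
+//   tritonserver --model-repository=/models \
+//       --host-policy=policy_0,numa-node=0 \
+//       --host-policy=policy_0,cpu-cores=0-15,32-47
+//
+// The policy name "policy_0" and the node/core values are made-up examples;
+// only the "numa-node" and "cpu-cores" keys are interpreted here, and the
+// "cpu-cores" value must be a comma-separated list of "lower-upper" ranges,
+// as enforced by SetNumaThreadAffinity() below.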
+#include "numa_utils.h" + +#ifndef _WIN32 +#include +#include +#endif +#include "triton/common/logging.h" + +namespace triton { namespace core { + +namespace { +std::string +VectorToString(const std::vector& vec) +{ + std::string str("["); + for (const auto& element : vec) { + str += std::to_string(element); + str += ","; + } + + str += "]"; + return str; +} + +Status +ParseIntOption(const std::string& msg, const std::string& arg, int* value) +{ + try { + *value = std::stoi(arg); + } + catch (const std::invalid_argument& ia) { + return Status( + Status::Code::INVALID_ARG, + msg + ": Can't parse '" + arg + "' to integer"); + } + return Status::Success; +} + +} // namespace + +// NUMA setting will be ignored on Windows platform +#ifdef _WIN32 +Status +SetNumaConfigOnThread( + const triton::common::HostPolicyCmdlineConfig& host_policy) +{ + return Status::Success; +} + +Status +SetNumaMemoryPolicy(const triton::common::HostPolicyCmdlineConfig& host_policy) +{ + return Status::Success; +} + +Status +GetNumaMemoryPolicyNodeMask(unsigned long* node_mask) +{ + *node_mask = 0; + return Status::Success; +} + +Status +ResetNumaMemoryPolicy() +{ + return Status::Success; +} + +Status +SetNumaThreadAffinity( + std::thread::native_handle_type thread, + const triton::common::HostPolicyCmdlineConfig& host_policy) +{ + return Status::Success; +} +#else +// Use variable to make sure no NUMA related function is actually called +// if Triton is not running with NUMA awareness. i.e. Extra docker permission +// is needed to call the NUMA functions and this ensures backward compatibility. +thread_local bool numa_set = false; + +Status +SetNumaConfigOnThread( + const triton::common::HostPolicyCmdlineConfig& host_policy) +{ + // Set thread affinity + RETURN_IF_ERROR(SetNumaThreadAffinity(pthread_self(), host_policy)); + + // Set memory policy + RETURN_IF_ERROR(SetNumaMemoryPolicy(host_policy)); + + return Status::Success; +} + +Status +SetNumaMemoryPolicy(const triton::common::HostPolicyCmdlineConfig& host_policy) +{ + const auto it = host_policy.find("numa-node"); + if (it != host_policy.end()) { + int node_id; + RETURN_IF_ERROR( + ParseIntOption("Parsing 'numa-node' value", it->second, &node_id)); + LOG_VERBOSE(1) << "Thread is binding to NUMA node " << it->second + << ". 
Max NUMA node count: " << (numa_max_node() + 1); + numa_set = true; + unsigned long node_mask = 1UL << node_id; + if (set_mempolicy(MPOL_BIND, &node_mask, (numa_max_node() + 1) + 1) != 0) { + return Status( + Status::Code::INTERNAL, + std::string("Unable to set NUMA memory policy: ") + strerror(errno)); + } + } + return Status::Success; +} + +Status +GetNumaMemoryPolicyNodeMask(unsigned long* node_mask) +{ + *node_mask = 0; + int mode; + if (numa_set && + get_mempolicy(&mode, node_mask, numa_max_node() + 1, NULL, 0) != 0) { + return Status( + Status::Code::INTERNAL, + std::string("Unable to get NUMA node for current thread: ") + + strerror(errno)); + } + return Status::Success; +} + +Status +ResetNumaMemoryPolicy() +{ + if (numa_set && (set_mempolicy(MPOL_DEFAULT, nullptr, 0) != 0)) { + return Status( + Status::Code::INTERNAL, + std::string("Unable to reset NUMA memory policy: ") + strerror(errno)); + } + numa_set = false; + return Status::Success; +} + +Status +SetNumaThreadAffinity( + std::thread::native_handle_type thread, + const triton::common::HostPolicyCmdlineConfig& host_policy) +{ + const auto it = host_policy.find("cpu-cores"); + if (it != host_policy.end()) { + // Parse CPUs + std::vector cpus; + { + const auto& cpu_str = it->second; + auto delim_cpus = cpu_str.find(","); + int current_pos = 0; + while (true) { + auto delim_range = cpu_str.find("-", current_pos); + if (delim_range == std::string::npos) { + return Status( + Status::Code::INVALID_ARG, + std::string("host policy setting 'cpu-cores' format is " + "'-'. Got ") + + cpu_str.substr( + current_pos, ((delim_cpus == std::string::npos) + ? (cpu_str.length() + 1) + : delim_cpus) - + current_pos)); + } + int lower, upper; + RETURN_IF_ERROR(ParseIntOption( + "Parsing 'cpu-cores' value", + cpu_str.substr(current_pos, delim_range - current_pos), &lower)); + RETURN_IF_ERROR(ParseIntOption( + "Parsing 'cpu-cores' value", + (delim_cpus == std::string::npos) + ? cpu_str.substr(delim_range + 1) + : cpu_str.substr( + delim_range + 1, delim_cpus - (delim_range + 1)), + &upper)); + for (; lower <= upper; ++lower) { + cpus.push_back(lower); + } + // break if the processed range is the last specified range + if (delim_cpus != std::string::npos) { + current_pos = delim_cpus + 1; + delim_cpus = cpu_str.find(",", current_pos); + } else { + break; + } + } + } + + LOG_VERBOSE(1) << "Thread is binding to one of the CPUs: " + << VectorToString(cpus); + numa_set = true; + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + for (int cpu : cpus) { + CPU_SET(cpu, &cpuset); + } + if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset) != 0) { + return Status( + Status::Code::INTERNAL, + std::string("Unable to set NUMA thread affinity: ") + + strerror(errno)); + } + } + return Status::Success; +} +#endif + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/numa_utils.h b/3rdparty/core-r22.12/src/numa_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..bb226bdfc23b31f2623c838e5f4cafbd4d18c914 --- /dev/null +++ b/3rdparty/core-r22.12/src/numa_utils.h @@ -0,0 +1,57 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include "status.h" +#include "triton/common/model_config.h" +#include "tritonserver_apis.h" + +namespace triton { namespace core { + +// Helper function to set memory policy and thread affinity on current thread +Status SetNumaConfigOnThread( + const triton::common::HostPolicyCmdlineConfig& host_policy); + +// Restrict the memory allocation to specific NUMA node. +Status SetNumaMemoryPolicy( + const triton::common::HostPolicyCmdlineConfig& host_policy); + +// Retrieve the node mask used to set memory policy for the current thread +Status GetNumaMemoryPolicyNodeMask(unsigned long* node_mask); + +// Reset the memory allocation setting. +Status ResetNumaMemoryPolicy(); + +// Set a thread affinity to be on specific cpus. +Status SetNumaThreadAffinity( + std::thread::native_handle_type thread, + const triton::common::HostPolicyCmdlineConfig& host_policy); + + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/payload.cc b/3rdparty/core-r22.12/src/payload.cc new file mode 100644 index 0000000000000000000000000000000000000000..c5c2fa26b408eca5bf4ac3b75003171b792f47e7 --- /dev/null +++ b/3rdparty/core-r22.12/src/payload.cc @@ -0,0 +1,215 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "payload.h" + +namespace triton { namespace core { + +Payload::Payload() + : op_type_(Operation::INFER_RUN), + requests_(std::vector>()), + OnCallback_([]() {}), instance_(nullptr), state_(State::UNINITIALIZED), + batcher_start_ns_(0), saturated_(false) +{ + exec_mu_.reset(new std::mutex()); +} + +const Status& +Payload::MergePayload(std::shared_ptr& payload) +{ + if ((payload->GetOpType() != Operation::INFER_RUN) || + (op_type_ != Operation::INFER_RUN)) { + static Status op_type_error( + Status::Code::INTERNAL, + "Attempted to merge payloads of type that are not INFER_RUN"); + return op_type_error; + } + if (payload->GetInstance() != instance_) { + static Status instance_error( + Status::Code::INTERNAL, + "Attempted to merge payloads of mismatching instance"); + return instance_error; + } + if ((payload->GetState() != State::EXECUTING) || + (state_ != State::EXECUTING)) { + static Status state_error( + Status::Code::INTERNAL, + "Attempted to merge payloads that are not in executing state"); + return state_error; + } + + // Skip comparison if not initialized (required), here assume either all + // payloads are initialized or otherwise. + if (required_equal_inputs_.Initialized() && + !required_equal_inputs_.HasEqualInputs(*payload->Requests().begin())) { + static Status shape_error( + Status::Code::INVALID_ARG, + "Attempted to merge payloads that has non-equal inputs"); + return shape_error; + } + + requests_.insert( + requests_.end(), std::make_move_iterator(payload->Requests().begin()), + std::make_move_iterator(payload->Requests().end())); + + payload->Callback(); + + return Status::Success; +} + +void +Payload::Reset(const Operation op_type, TritonModelInstance* instance) +{ + op_type_ = op_type; + requests_.clear(); + OnCallback_ = []() {}; + release_callbacks_.clear(); + instance_ = instance; + state_ = State::UNINITIALIZED; + status_.reset(new std::promise()); + required_equal_inputs_ = RequiredEqualInputs(); + batcher_start_ns_ = 0; + saturated_ = false; +} + +void +Payload::Release() +{ + op_type_ = Operation::INFER_RUN; + requests_.clear(); + OnCallback_ = []() {}; + release_callbacks_.clear(); + instance_ = nullptr; + state_ = State::RELEASED; + required_equal_inputs_ = RequiredEqualInputs(); + batcher_start_ns_ = 0; + saturated_ = false; +} + +size_t +Payload::BatchSize() +{ + size_t batch_size = 0; + for (const auto& request : requests_) { + batch_size += std::max(1U, request->BatchSize()); + } + return batch_size; +} + +void +Payload::ReserveRequests(size_t size) +{ + requests_.reserve(size); +} + +void +Payload::AddRequest(std::unique_ptr request) +{ + if ((batcher_start_ns_ == 0) || + (batcher_start_ns_ > request->BatcherStartNs())) { + batcher_start_ns_ = request->BatcherStartNs(); + } + requests_.push_back(std::move(request)); +} + +void +Payload::SetCallback(std::function OnCallback) +{ + OnCallback_ = OnCallback; +} + +void +Payload::SetInstance(TritonModelInstance* model_instance) +{ + instance_ = model_instance; +} + +void 
+Payload::AddInternalReleaseCallback(std::function&& callback) +{ + release_callbacks_.emplace_back(std::move(callback)); +} + +void +Payload::MarkSaturated() +{ + saturated_ = true; +} + +void +Payload::SetState(Payload::State state) +{ + state_ = state; +} + +Status +Payload::Wait() +{ + return status_->get_future().get(); +} + +void +Payload::Callback() +{ + OnCallback_(); +} + +void +Payload::OnRelease() +{ + // Invoke the release callbacks added internally before releasing the + // request to user provided callback. + for (auto it = release_callbacks_.rbegin(); it != release_callbacks_.rend(); + it++) { + (*it)(); + } + release_callbacks_.clear(); +} + +void +Payload::Execute(bool* should_exit) +{ + *should_exit = false; + + Status status; + switch (op_type_) { + case Operation::INFER_RUN: + instance_->Schedule(std::move(requests_), OnCallback_); + break; + case Operation::INIT: + status = instance_->Initialize(); + break; + case Operation::WARM_UP: + status = instance_->WarmUp(); + break; + case Operation::EXIT: + *should_exit = true; + } + + status_->set_value(status); +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/payload.h b/3rdparty/core-r22.12/src/payload.h new file mode 100644 index 0000000000000000000000000000000000000000..1650917ae20053b457dad06bd15c5dd2a965b6d5 --- /dev/null +++ b/3rdparty/core-r22.12/src/payload.h @@ -0,0 +1,102 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
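+// Hedged lifecycle sketch for the Payload class declared below. The calls are
+// taken from payload.cc and rate_limiter.cc in this same import, but the
+// surrounding scheduler flow is an approximation, not a contract:
+//
+//   std::shared_ptr<Payload> p =
+//       rate_limiter->GetPayload(Payload::Operation::INFER_RUN, instance);
+//   p->AddRequest(std::move(request));        // gather requests to batch
+//   rate_limiter->EnqueuePayload(model, p);   // REQUESTED, then SCHEDULED
+//   ...                                       // a backend thread dequeues it
+//   bool should_exit = false;
+//   p->Execute(&should_exit);                 // INFER_RUN: instance->Schedule(...)
+//   p->Wait();                                // future of the promise set in
+//                                             // Execute(); used e.g. for INIT
+//                                             // and WARM_UP payloads
+//   rate_limiter->PayloadRelease(p);          // recycle into the payload bucket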
+#pragma once + +#include +#include +#include +#include +#include +#include + +#include "backend_model_instance.h" +#include "infer_request.h" +#include "scheduler_utils.h" +#include "status.h" + +namespace triton { namespace core { + +class Payload { + public: + enum Operation { INFER_RUN = 0, INIT = 1, WARM_UP = 2, EXIT = 3 }; + enum State { + UNINITIALIZED = 0, + READY = 1, + REQUESTED = 2, + SCHEDULED = 3, + EXECUTING = 4, + RELEASED = 5 + }; + + Payload(); + void Reset(const Operation op_type, TritonModelInstance* instance = nullptr); + const Status& MergePayload(std::shared_ptr& payload); + Operation GetOpType() { return op_type_; } + std::mutex* GetExecMutex() { return exec_mu_.get(); } + size_t RequestCount() { return requests_.size(); } + size_t BatchSize(); + void ReserveRequests(size_t size); + void AddRequest(std::unique_ptr request); + std::vector>& Requests() + { + return requests_; + } + uint64_t BatcherStartNs() { return batcher_start_ns_; } + void SetCallback(std::function OnCallback); + void Callback(); + void AddInternalReleaseCallback(std::function&& callback); + void OnRelease(); + void SetInstance(TritonModelInstance* model_instance); + TritonModelInstance* GetInstance() { return instance_; } + void MarkSaturated(); + bool IsSaturated() { return saturated_; } + RequiredEqualInputs* MutableRequiredEqualInputs() + { + return &required_equal_inputs_; + } + + State GetState() { return state_; } + void SetState(State state); + void Execute(bool* should_exit); + Status Wait(); + void Release(); + + private: + Operation op_type_; + std::vector> requests_; + std::function OnCallback_; + std::vector> release_callbacks_; + TritonModelInstance* instance_; + State state_; + std::unique_ptr> status_; + std::unique_ptr exec_mu_; + uint64_t batcher_start_ns_; + RequiredEqualInputs required_equal_inputs_; + + bool saturated_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/pinned_memory_manager.cc b/3rdparty/core-r22.12/src/pinned_memory_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b4ffd42207ccf22a6f8da30f7785832f378a26b --- /dev/null +++ b/3rdparty/core-r22.12/src/pinned_memory_manager.cc @@ -0,0 +1,378 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// + +#include "pinned_memory_manager.h" + +#include +#include "numa_utils.h" +#include "triton/common/logging.h" + +#ifdef TRITON_ENABLE_GPU +#include +#endif // TRITON_ENABLE_GPU + +namespace triton { namespace core { + +namespace { + +std::string +PointerToString(void* ptr) +{ + std::stringstream ss; + ss << ptr; + return ss.str(); +} + +Status +ParseIntOption(const std::string& msg, const std::string& arg, int* value) +{ + try { + *value = std::stoi(arg); + } + catch (const std::invalid_argument& ia) { + return Status( + Status::Code::INVALID_ARG, + msg + ": Can't parse '" + arg + "' to integer"); + } + return Status::Success; +} + +} // namespace + +std::unique_ptr PinnedMemoryManager::instance_; +uint64_t PinnedMemoryManager::pinned_memory_byte_size_; + +PinnedMemoryManager::PinnedMemory::PinnedMemory( + void* pinned_memory_buffer, uint64_t size) + : pinned_memory_buffer_(pinned_memory_buffer) +{ + if (pinned_memory_buffer_ != nullptr) { + managed_pinned_memory_ = boost::interprocess::managed_external_buffer( + boost::interprocess::create_only_t{}, pinned_memory_buffer_, size); + } +} + + +PinnedMemoryManager::PinnedMemory::~PinnedMemory() +{ +#ifdef TRITON_ENABLE_GPU + if (pinned_memory_buffer_ != nullptr) { + cudaFreeHost(pinned_memory_buffer_); + } +#endif // TRITON_ENABLE_GPU +} + +PinnedMemoryManager::~PinnedMemoryManager() +{ + // Clean up + for (const auto& memory_info : memory_info_) { + const auto& is_pinned = memory_info.second.first; + if (!is_pinned) { + free(memory_info.first); + } + } +} + +void +PinnedMemoryManager::AddPinnedMemoryBuffer( + const std::shared_ptr& pinned_memory_buffer, + unsigned long node_mask) +{ + pinned_memory_buffers_[node_mask] = pinned_memory_buffer; +} + +Status +PinnedMemoryManager::AllocInternal( + void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type, + bool allow_nonpinned_fallback, PinnedMemory* pinned_memory_buffer) +{ + auto status = Status::Success; + if (pinned_memory_buffer->pinned_memory_buffer_ != nullptr) { + std::lock_guard lk(pinned_memory_buffer->buffer_mtx_); + *ptr = pinned_memory_buffer->managed_pinned_memory_.allocate( + size, std::nothrow_t{}); + *allocated_type = TRITONSERVER_MEMORY_CPU_PINNED; + if (*ptr == nullptr) { + status = Status( + Status::Code::INTERNAL, "failed to allocate pinned system memory"); + } + } else { + status = Status( + Status::Code::INTERNAL, + "failed to allocate pinned system memory: no pinned memory pool"); + } + + bool is_pinned = true; + if ((!status.IsOk()) && allow_nonpinned_fallback) { + static bool warning_logged = false; + if (!warning_logged) { + LOG_WARNING << status.Message() + << ", falling back to non-pinned system memory"; + warning_logged = true; + } + *ptr = malloc(size); + *allocated_type = TRITONSERVER_MEMORY_CPU; + is_pinned = false; + if (*ptr == nullptr) { + status = Status( + Status::Code::INTERNAL, + "failed to allocate non-pinned system memory"); + } else { + status = Status::Success; + } + } + + // keep track of allocated buffer 
or clean up + { + std::lock_guard lk(info_mtx_); + if (status.IsOk()) { + auto res = memory_info_.emplace( + *ptr, std::make_pair(is_pinned, pinned_memory_buffer)); + if (!res.second) { + status = Status( + Status::Code::INTERNAL, "unexpected memory address collision, '" + + PointerToString(*ptr) + + "' has been managed"); + } + LOG_VERBOSE(1) << (is_pinned ? "" : "non-") + << "pinned memory allocation: " + << "size " << size << ", addr " << *ptr; + } + } + + if ((!status.IsOk()) && (*ptr != nullptr)) { + if (is_pinned) { + std::lock_guard lk(pinned_memory_buffer->buffer_mtx_); + pinned_memory_buffer->managed_pinned_memory_.deallocate(*ptr); + } else { + free(*ptr); + } + } + + return status; +} + +Status +PinnedMemoryManager::FreeInternal(void* ptr) +{ + bool is_pinned = true; + PinnedMemory* pinned_memory_buffer = nullptr; + { + std::lock_guard lk(info_mtx_); + auto it = memory_info_.find(ptr); + if (it != memory_info_.end()) { + is_pinned = it->second.first; + pinned_memory_buffer = it->second.second; + LOG_VERBOSE(1) << (is_pinned ? "" : "non-") + << "pinned memory deallocation: " + << "addr " << ptr; + memory_info_.erase(it); + } else { + return Status( + Status::Code::INTERNAL, "unexpected memory address '" + + PointerToString(ptr) + + "' is not being managed"); + } + } + + if (is_pinned) { + std::lock_guard lk(pinned_memory_buffer->buffer_mtx_); + pinned_memory_buffer->managed_pinned_memory_.deallocate(ptr); + } else { + free(ptr); + } + return Status::Success; +} + +void +PinnedMemoryManager::Reset() +{ + instance_.reset(); +} + +Status +PinnedMemoryManager::Create(const Options& options) +{ + if (instance_ != nullptr) { + LOG_WARNING << "New pinned memory pool of size " + << options.pinned_memory_pool_byte_size_ + << " could not be created since one already exists" + << " of size " << pinned_memory_byte_size_; + return Status::Success; + } + + instance_.reset(new PinnedMemoryManager()); + if (options.host_policy_map_.empty()) { + void* buffer = nullptr; +#ifdef TRITON_ENABLE_GPU + auto err = cudaHostAlloc( + &buffer, options.pinned_memory_pool_byte_size_, cudaHostAllocPortable); + if (err != cudaSuccess) { + buffer = nullptr; + LOG_WARNING << "Unable to allocate pinned system memory, pinned memory " + "pool will not be available: " + << std::string(cudaGetErrorString(err)); + } else if (options.pinned_memory_pool_byte_size_ != 0) { + LOG_INFO << "Pinned memory pool is created at '" + << PointerToString(buffer) << "' with size " + << options.pinned_memory_pool_byte_size_; + } else { + LOG_INFO << "Pinned memory pool disabled"; + } +#endif // TRITON_ENABLE_GPU + try { + instance_->AddPinnedMemoryBuffer( + std::shared_ptr( + new PinnedMemory(buffer, options.pinned_memory_pool_byte_size_)), + 0); + } + catch (const std::exception& ex) { + return Status( + Status::Code::INTERNAL, + "Failed to add Pinned Memory buffer: " + std::string(ex.what())); + } + } else { + // Create only one buffer / manager should be created for one node, + // and all associated devices should request memory from the shared manager + std::map numa_map; + for (const auto host_policy : options.host_policy_map_) { + const auto numa_it = host_policy.second.find("numa-node"); + if (numa_it != host_policy.second.end()) { + int32_t numa_id; + if (ParseIntOption("Parsing NUMA node", numa_it->second, &numa_id) + .IsOk()) { + numa_map.emplace(numa_id, host_policy.first); + } + } + } + for (const auto node_policy : numa_map) { + auto status = + SetNumaMemoryPolicy(options.host_policy_map_.at(node_policy.second)); + if 
(!status.IsOk()) { + LOG_WARNING << "Unable to allocate pinned system memory for NUMA node " + << node_policy.first << ": " << status.AsString(); + continue; + } + unsigned long node_mask; + status = GetNumaMemoryPolicyNodeMask(&node_mask); + if (!status.IsOk()) { + LOG_WARNING << "Unable to get NUMA node set for current thread: " + << status.AsString(); + continue; + } + void* buffer = nullptr; +#ifdef TRITON_ENABLE_GPU + auto err = cudaHostAlloc( + &buffer, options.pinned_memory_pool_byte_size_, + cudaHostAllocPortable); + if (err != cudaSuccess) { + buffer = nullptr; + LOG_WARNING << "Unable to allocate pinned system memory, pinned memory " + "pool will not be available: " + << std::string(cudaGetErrorString(err)); + } else if (options.pinned_memory_pool_byte_size_ != 0) { + LOG_INFO << "Pinned memory pool is created at '" + << PointerToString(buffer) << "' with size " + << options.pinned_memory_pool_byte_size_; + } else { + LOG_INFO << "Pinned memory pool disabled"; + } +#endif // TRITON_ENABLE_GPU + ResetNumaMemoryPolicy(); + try { + instance_->AddPinnedMemoryBuffer( + std::shared_ptr(new PinnedMemory( + buffer, options.pinned_memory_pool_byte_size_)), + node_mask); + } + catch (const std::exception& ex) { + return Status( + Status::Code::INTERNAL, + "Failed to add Pinned Memory buffer with host policy: " + + std::string(ex.what())); + } + } + // If no pinned memory is allocated, add an empty entry where all allocation + // will be on normal system memory + if (instance_->pinned_memory_buffers_.empty()) { + try { + instance_->AddPinnedMemoryBuffer( + std::shared_ptr(new PinnedMemory( + nullptr, options.pinned_memory_pool_byte_size_)), + 0); + } + catch (const std::exception& ex) { + return Status( + Status::Code::INTERNAL, + "Failed to add empty Pinned Memory entry: " + + std::string(ex.what())); + } + } + } + pinned_memory_byte_size_ = options.pinned_memory_pool_byte_size_; + return Status::Success; +} + +Status +PinnedMemoryManager::Alloc( + void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type, + bool allow_nonpinned_fallback) +{ + if (instance_ == nullptr) { + return Status( + Status::Code::UNAVAILABLE, "PinnedMemoryManager has not been created"); + } + + auto pinned_memory_buffer = + instance_->pinned_memory_buffers_.begin()->second.get(); + if (instance_->pinned_memory_buffers_.size() > 1) { + unsigned long node_mask; + if (GetNumaMemoryPolicyNodeMask(&node_mask).IsOk()) { + auto it = instance_->pinned_memory_buffers_.find(node_mask); + if (it != instance_->pinned_memory_buffers_.end()) { + pinned_memory_buffer = it->second.get(); + } + } + } + + return instance_->AllocInternal( + ptr, size, allocated_type, allow_nonpinned_fallback, + pinned_memory_buffer); +} + +Status +PinnedMemoryManager::Free(void* ptr) +{ + if (instance_ == nullptr) { + return Status( + Status::Code::UNAVAILABLE, "PinnedMemoryManager has not been created"); + } + + return instance_->FreeInternal(ptr); +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/pinned_memory_manager.h b/3rdparty/core-r22.12/src/pinned_memory_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..1236f06b4f73d41c4bdba9e0a9ff10118a106344 --- /dev/null +++ b/3rdparty/core-r22.12/src/pinned_memory_manager.h @@ -0,0 +1,108 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +#pragma once + +#include +#include +#include +#include +#include "status.h" +#include "triton/common/model_config.h" + +namespace triton { namespace core { + +// This is a singleton class responsible for maintaining pinned memory pool +// used by the inference server. Pinned memory allocations and deallocations +// must be requested via functions provided by this class. +class PinnedMemoryManager { + public: + // Options to configure pinned memeory manager. + struct Options { + Options( + uint64_t b = 0, + const triton::common::HostPolicyCmdlineConfigMap& host_policy_map = {}) + : pinned_memory_pool_byte_size_(b), host_policy_map_(host_policy_map) + { + } + + uint64_t pinned_memory_pool_byte_size_; + triton::common::HostPolicyCmdlineConfigMap host_policy_map_; + }; + + ~PinnedMemoryManager(); + + // Create the pinned memory manager based on 'options' specified. + // Return Status object indicating success or failure. + static Status Create(const Options& options); + + // Allocate pinned memory with the requested 'size' and return the pointer + // in 'ptr'. If 'allow_nonpinned_fallback' is true, regular system memory + // will be allocated as fallback in the case where pinned memory fails to + // be allocated. + // Return Status object indicating success or failure. + static Status Alloc( + void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type, + bool allow_nonpinned_fallback); + + // Free the memory allocated by the pinned memory manager. + // Return Status object indicating success or failure. + static Status Free(void* ptr); + + protected: + // Provide explicit control on the lifecycle of the CUDA memory manager, + // for testing only. 
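+  // (Reset() simply destroys the singleton instance_, so a subsequent call to
+  // Create() can rebuild the pinned memory pool with different Options.)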
+ static void Reset(); + + private: + class PinnedMemory { + public: + PinnedMemory(void* pinned_memory_buffer, uint64_t size); + ~PinnedMemory(); + void* pinned_memory_buffer_; + std::mutex buffer_mtx_; + boost::interprocess::managed_external_buffer managed_pinned_memory_; + }; + + PinnedMemoryManager() = default; + + Status AllocInternal( + void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type, + bool allow_nonpinned_fallback, PinnedMemory* pinned_memory_buffer); + Status FreeInternal(void* ptr); + void AddPinnedMemoryBuffer( + const std::shared_ptr& pinned_memory_buffer, + unsigned long node_mask); + + static std::unique_ptr instance_; + static uint64_t pinned_memory_byte_size_; + + std::mutex info_mtx_; + std::map> memory_info_; + std::map> pinned_memory_buffers_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/rate_limiter.cc b/3rdparty/core-r22.12/src/rate_limiter.cc new file mode 100644 index 0000000000000000000000000000000000000000..8052281332f061d1e06476b31e83dac17965019a --- /dev/null +++ b/3rdparty/core-r22.12/src/rate_limiter.cc @@ -0,0 +1,943 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
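+// Hedged configuration sketch (resource name and counts are made up): the
+// RateLimiterConfig passed to RegisterModelInstance() below normally comes
+// from the model's config.pbtxt instance_group entry, roughly:
+//
+//   instance_group [
+//     {
+//       count: 2
+//       kind: KIND_GPU
+//       rate_limiter {
+//         resources [ { name: "R1" count: 4 } ]
+//         priority: 2
+//       }
+//     }
+//   ]
+//
+// When rate limiting is disabled (ignore_resources_and_priority_ == true),
+// EnqueuePayload() schedules the payload immediately; otherwise scheduling is
+// deferred until the instance's resources can be allocated (see
+// DeferPayloadSchedule() and AttemptAllocation()).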
+ +#include "rate_limiter.h" + +#include +#include "triton/common/logging.h" + +namespace triton { namespace core { + +constexpr size_t MAX_PAYLOAD_BUCKET_COUNT = 1000; + +//========================================================================= +// Core Implementation +//========================================================================= + +Status +RateLimiter::Create( + const bool ignore_resources_and_priority, + const RateLimiter::ResourceMap& resource_map, + std::unique_ptr* rate_limiter) +{ + std::unique_ptr local_rate_limiter( + new RateLimiter(ignore_resources_and_priority, resource_map)); + *rate_limiter = std::move(local_rate_limiter); + + return Status::Success; +} + +Status +RateLimiter::RegisterModelInstance( + TritonModelInstance* triton_model_instance, + const RateLimiterConfig& rate_limiter_config) +{ + { + std::lock_guard lk1(model_ctx_mtx_); + std::lock_guard lk2(model_instance_ctx_mtx_); + + auto& model_context = model_contexts_[triton_model_instance->Model()]; + auto& model_instances = + model_instance_ctxs_[triton_model_instance->Model()]; + + model_instances.push_back( + std::shared_ptr(new ModelInstanceContext( + triton_model_instance, &model_context, rate_limiter_config, + [this](ModelInstanceContext* instance) { OnStage(instance); }, + [this](ModelInstanceContext* instance) { OnRelease(instance); }))); + model_context.AddAvailableInstance(model_instances.back().get()); + model_context.AddSpecificRequestQueue(); + + if (!ignore_resources_and_priority_) { + resource_manager_->AddModelInstance(model_instances.back().get()); + RETURN_IF_ERROR(resource_manager_->UpdateResourceLimits()); + } + } + + InitializePayloadQueues(triton_model_instance); + + return Status::Success; +} + +Status +RateLimiter::UnregisterModel(const TritonModel* model) +{ + { + std::lock_guard lk1(model_ctx_mtx_); + std::lock_guard lk2(model_instance_ctx_mtx_); + + auto& model_context = model_contexts_[model]; + + model_context.RequestRemoval(); + for (const auto& instance : model_instance_ctxs_[model]) { + instance->WaitForRemoval(); + if (!ignore_resources_and_priority_) { + resource_manager_->RemoveModelInstance(instance.get()); + } + } + + model_instance_ctxs_.erase(model); + model_contexts_.erase(model); + } + + if (!ignore_resources_and_priority_) { + RETURN_IF_ERROR(resource_manager_->UpdateResourceLimits()); + } + + { + std::lock_guard lk(payload_queues_mu_); + if (payload_queues_.find(model) != payload_queues_.end()) { + payload_queues_.erase(model); + } + } + + return Status::Success; +} + +bool +RateLimiter::PayloadSlotAvailable(const TritonModel* model) +{ + bool result; + PayloadQueue* payload_queue = payload_queues_[model].get(); + { + std::lock_guard lk(payload_queue->mu_); + result = payload_queue->queue_->Size() < + 2 * payload_queue->specific_queues_.size(); + } + return result; +} + +Status +RateLimiter::EnqueuePayload( + const TritonModel* model, std::shared_ptr payload) +{ + auto pinstance = payload->GetInstance(); + if (payload_queues_.find(model) == payload_queues_.end()) { + LOG_INFO << "Should not print this "; + } + PayloadQueue* payload_queue = payload_queues_[model].get(); + { + std::lock_guard lk(payload_queue->mu_); + payload->SetState(Payload::State::REQUESTED); + if (ignore_resources_and_priority_) { + SchedulePayload(pinstance, payload_queue, payload); + } + } + if (ignore_resources_and_priority_) { + if (pinstance == nullptr) { + payload_queue->cv_.notify_one(); + } else { + payload_queue->cv_.notify_all(); + } + } else { + StandardScheduleFunc sched_func = 
[this, payload_queue, + payload](ModelInstanceContext* mi) { + { + std::lock_guard lk(payload_queue->mu_); + this->SchedulePayload(mi->RawInstance(), payload_queue, payload); + } + auto cb = [mi]() { mi->Release(); }; + payload->AddInternalReleaseCallback(cb); + if (mi->RawInstance() == nullptr) { + payload_queue->cv_.notify_one(); + } else { + payload_queue->cv_.notify_all(); + } + }; + DeferPayloadSchedule(sched_func, model, payload->GetInstance()); + } + return Status::Success; +} + +void +RateLimiter::DequeuePayload( + std::deque& instances, + std::shared_ptr* payload) +{ + payload->reset(); + if (payload_queues_.find(instances[0]->Model()) == payload_queues_.end()) { + LOG_INFO << "Should not print this "; + } + PayloadQueue* payload_queue = payload_queues_[instances[0]->Model()].get(); + std::vector> merged_payloads; + size_t instance_index = std::numeric_limits::max(); + { + std::unique_lock lk(payload_queue->mu_); + payload_queue->cv_.wait(lk, [&instances, &instance_index, payload_queue]() { + bool empty = payload_queue->queue_->Empty(); + if (empty) { + instance_index = 0; + for (const auto instance : instances) { + empty = payload_queue->specific_queues_[instance]->Empty(); + if (empty) { + instance_index++; + } else { + break; + } + } + } + return !empty; + }); + if (instance_index < instances.size()) { + TritonModelInstance* instance = instances[instance_index]; + if (!payload_queue->specific_queues_[instance]->Empty()) { + payload_queue->specific_queues_[instance]->Dequeue( + payload, &merged_payloads); + } + } else { + payload_queue->queue_->Dequeue(payload, &merged_payloads); + } + } + for (auto& merge_payload : merged_payloads) { + PayloadRelease(merge_payload); + } + (*payload)->Callback(); + if ((*payload)->GetInstance() == nullptr) { + (*payload)->SetInstance(instances.front()); + instances.pop_front(); + } else { + instances.erase(instances.begin() + instance_index); + } +} + +std::shared_ptr +RateLimiter::GetPayload( + const Payload::Operation op_type, TritonModelInstance* instance) +{ + std::shared_ptr payload; + + if (max_payload_bucket_count_ > 0) { + std::lock_guard lock(payload_mu_); + + if (!payload_bucket_.empty()) { + payload = payload_bucket_.back(); + payload_bucket_.pop_back(); + } + if (payload.get() == nullptr && (!payloads_in_use_.empty())) { + // Just checking the front of the queue instead the entire queue for + // an available payload to save time. + if (payloads_in_use_.front().use_count() == 1) { + payload = payloads_in_use_.front(); + payloads_in_use_.pop_front(); + } + } + } + + if (payload.get() == nullptr) { + payload.reset(new Payload()); + } + + payload->Reset(op_type, instance); + return payload; +} + +void +RateLimiter::PayloadRelease(std::shared_ptr& payload) +{ + payload->OnRelease(); + if (max_payload_bucket_count_ > 0) { + std::lock_guard lock(payload_mu_); + + if (payloads_in_use_.size() + payload_bucket_.size() < + max_payload_bucket_count_) { + // Release iff the payload shared_ptr is uniquely held. 
+ if (payload.use_count() == 1) { + payload->Release(); + payload_bucket_.push_back(std::move(payload)); + return; + } else { + payloads_in_use_.push_back(std::move(payload)); + } + } + } +} + +RateLimiter::RateLimiter( + const bool ignore_resources_and_priority, const ResourceMap& resource_map) + : ignore_resources_and_priority_(ignore_resources_and_priority), + max_payload_bucket_count_(MAX_PAYLOAD_BUCKET_COUNT) +{ + ResourceManager::Create(resource_map, &resource_manager_); +} + +void +RateLimiter::InitializePayloadQueues(const TritonModelInstance* instance) +{ + auto& config = instance->Model()->Config(); + uint64_t max_queue_delay_microseconds; + if (config.has_sequence_batching()) { + const auto& batcher_config = config.sequence_batching(); + if (batcher_config.has_oldest()) { + max_queue_delay_microseconds = + batcher_config.oldest().max_queue_delay_microseconds(); + } else { + max_queue_delay_microseconds = 0; + } + } else if (config.has_dynamic_batching()) { + max_queue_delay_microseconds = + config.dynamic_batching().max_queue_delay_microseconds(); + } else { + max_queue_delay_microseconds = 0; + } + { + std::lock_guard lk(payload_queues_mu_); + if (payload_queues_.find(instance->Model()) == payload_queues_.end()) { + payload_queues_.emplace( + instance->Model(), + new PayloadQueue( + config.max_batch_size(), max_queue_delay_microseconds * 1000)); + } + } + PayloadQueue* payload_queue = payload_queues_[instance->Model()].get(); + if (payload_queue->specific_queues_.find(instance) == + payload_queue->specific_queues_.end()) { + payload_queue->specific_queues_.emplace( + instance, + new InstanceQueue( + config.max_batch_size(), max_queue_delay_microseconds * 1000)); + } +} + +Status +RateLimiter::DeferPayloadSchedule( + const StandardScheduleFunc& OnSchedule, const TritonModel* model, + TritonModelInstance* triton_model_instance) +{ + std::lock_guard lk(model_ctx_mtx_); + + auto itr = model_contexts_.find(model); + if (itr == model_contexts_.end()) { + return Status( + Status::Code::INTERNAL, + "Requested model is not yet registered with rate limiter"); + } + + if (itr->second.isRemovalInProgress()) { + return Status( + Status::Code::INTERNAL, + "New model requests can not be made to a model that is being " + "removed"); + } + + itr->second.EnqueueModelInstanceRequest(OnSchedule, triton_model_instance); + itr->second.StageInstanceIfAvailable(triton_model_instance); + + return Status::Success; +} + +void +RateLimiter::SchedulePayload( + TritonModelInstance* tmi, PayloadQueue* payload_queue, + const std::shared_ptr& payload) +{ + if (tmi == nullptr) { + payload_queue->queue_->Enqueue(payload); + } else { + payload_queue->specific_queues_[tmi]->Enqueue(payload); + } + payload->SetState(Payload::State::SCHEDULED); +} + +void +RateLimiter::OnStage(ModelInstanceContext* instance) +{ + { + std::lock_guard lk(staged_instances_mtx_); + staged_instances_.push(instance); + } + AttemptAllocation(); +} + +void +RateLimiter::OnRelease(ModelInstanceContext* instance) +{ + auto& model_context = model_contexts_[instance->RawInstance()->Model()]; + model_context.AddAvailableInstance(instance); + resource_manager_->ReleaseResources(instance); + if (model_context.ContainsPendingRequests(instance->RawInstance()->Index())) { + model_context.StageInstanceIfAvailable(instance->RawInstance()); + } + AttemptAllocation(); +} + +void +RateLimiter::AttemptAllocation() +{ + std::lock_guard lk(staged_instances_mtx_); + if (!staged_instances_.empty()) { + ModelInstanceContext* instance = staged_instances_.top(); 
+ if (resource_manager_->AllocateResources(instance)) { + staged_instances_.pop(); + instance->Allocate(); + } + } +} + +//========================================================================= +// ModelContext Implementation +//========================================================================= + +RateLimiter::ModelContext::ModelContext() : removal_in_progress_(false) {} + +Status +RateLimiter::ModelContext::EnqueueModelInstanceRequest( + const StandardScheduleFunc& OnSchedule, + TritonModelInstance* triton_model_instance) +{ + std::lock_guard lk(sched_request_queue_mtx_); + + if (triton_model_instance == nullptr) { + generic_sched_request_queue_.push(OnSchedule); + } else if ( + (uint32_t)triton_model_instance->Index() < + specific_sched_request_queues_.size()) { + specific_sched_request_queues_[triton_model_instance->Index()].push( + OnSchedule); + } else { + return Status( + Status::Code::INTERNAL, + "expected instance index between 0 and " + + std::to_string(specific_sched_request_queues_.size()) + ", got " + + std::to_string(triton_model_instance->Index())); + } + + return Status::Success; +} + +void +RateLimiter::ModelContext::AddAvailableInstance(ModelInstanceContext* instance) +{ + std::lock_guard lk(avbl_instances_mtx_); + avbl_instances_.push(instance); + instance->MarkAvailable(); +} + + +void +RateLimiter::ModelContext::StageInstanceIfAvailable( + TritonModelInstance* req_instance) +{ + std::lock_guard lk1(sched_request_queue_mtx_); + std::lock_guard lk2(avbl_instances_mtx_); + PriorityQueue backup_queue; + + while (!avbl_instances_.empty()) { + ModelInstanceContext* instance = avbl_instances_.top(); + if ((req_instance != nullptr) && + (instance->RawInstance() != req_instance)) { + backup_queue.push(instance); + avbl_instances_.pop(); + continue; + } + if (!specific_sched_request_queues_[instance->RawInstance()->Index()] + .empty()) { + // Prioritize the specific requests for the available model + // instance highest priority. + const StandardScheduleFunc func = + specific_sched_request_queues_[instance->RawInstance()->Index()] + .front(); + specific_sched_request_queues_[instance->RawInstance()->Index()].pop(); + instance->Stage(func); + } else if (!generic_sched_request_queue_.empty()) { + // If request is for generic model instance then use the + // instance with the highest priority. + const StandardScheduleFunc func = generic_sched_request_queue_.front(); + generic_sched_request_queue_.pop(); + instance->Stage(func); + } else { + // If there are requests for a specific model instance then backup + // the model instance and keep searching through the available + // model instances. The prioritization will be taken care of in the + // staging priority queue. + backup_queue.push(instance); + } + avbl_instances_.pop(); + } + // Restore the backup queue + if (!backup_queue.empty()) { + avbl_instances_.swap(backup_queue); + } +} + +void +RateLimiter::ModelContext::AllocateInstanceIfAvailable() +{ + std::lock_guard lk1(sched_request_queue_mtx_); + std::lock_guard lk2(avbl_instances_mtx_); + PriorityQueue backup_queue; + while (!avbl_instances_.empty()) { + ModelInstanceContext* instance = avbl_instances_.top(); + if (!specific_sched_request_queues_[instance->RawInstance()->Index()] + .empty()) { + // Prioritize the specific requests for the available model + // instance highest priority. 
+ const StandardScheduleFunc func = + specific_sched_request_queues_[instance->RawInstance()->Index()] + .front(); + specific_sched_request_queues_[instance->RawInstance()->Index()].pop(); + instance->DirectAllocate(func); + } else if (!generic_sched_request_queue_.empty()) { + // If request is for generic model instance then use the + // instance with the highest priority. + const StandardScheduleFunc func = generic_sched_request_queue_.front(); + generic_sched_request_queue_.pop(); + instance->DirectAllocate(func); + } else { + // If there are requests for a specific model instance then backup + // the model instance and keep searching through the available + // model instances. The prioritization will be taken care of in the + // staging priority queue. + backup_queue.push(instance); + } + avbl_instances_.pop(); + } + // Restore the backup queue + if (!backup_queue.empty()) { + avbl_instances_.swap(backup_queue); + } +} + +void +RateLimiter::ModelContext::AddSpecificRequestQueue() +{ + std::lock_guard lk(sched_request_queue_mtx_); + specific_sched_request_queues_.emplace_back(); +} + +bool +RateLimiter::ModelContext::ContainsPendingRequests(int index) +{ + std::lock_guard lk(sched_request_queue_mtx_); + return (generic_sched_request_queue_.size() != 0) || + (specific_sched_request_queues_[index].size() != 0); +} + +void +RateLimiter::ModelContext::RequestRemoval() +{ + removal_in_progress_ = true; +} + + +//========================================================================= +// ModelInstanceContext Implementation +//========================================================================= + +RateLimiter::ModelInstanceContext::ModelInstanceContext( + TritonModelInstance* triton_model_instance, + RateLimiter::ModelContext* model_context, + const RateLimiter::RateLimiterConfig& rate_limiter_config, + RateLimiter::StandardStageFunc OnStage, + RateLimiter::StandardReleaseFunc OnRelease) + : triton_model_instance_(triton_model_instance), + index_(triton_model_instance->Index()), model_context_(model_context), + rate_limiter_config_(rate_limiter_config), OnStage_(OnStage), + OnRelease_(OnRelease), exec_count_(0), state_(AVAILABLE) +{ +} + +void +RateLimiter::ModelInstanceContext::MarkAvailable() +{ + std::lock_guard lk(state_mtx_); + state_ = AVAILABLE; +} + +Status +RateLimiter::ModelInstanceContext::Stage(StandardScheduleFunc OnSchedule) +{ + { + std::lock_guard lk(state_mtx_); + + if (state_ != AVAILABLE) { + return Status( + Status::Code::INTERNAL, + "Can not stage a model instance that is not yet available"); + } + + state_ = STAGED; + OnSchedule_ = OnSchedule; + } + + OnStage_(this); + + return Status::Success; +} + +Status +RateLimiter::ModelInstanceContext::Allocate() +{ + { + std::lock_guard lk(state_mtx_); + + if (state_ != STAGED) { + return Status( + Status::Code::INTERNAL, + "Can not allocate a model instance that is not yet staged"); + } + + state_ = ALLOCATED; + } + + OnSchedule_(this); + + return Status::Success; +} + +Status +RateLimiter::ModelInstanceContext::DirectAllocate( + StandardScheduleFunc OnSchedule) +{ + { + std::lock_guard lk(state_mtx_); + + if (state_ != AVAILABLE) { + return Status( + Status::Code::INTERNAL, + "Can not allocate a model instance that is not yet available"); + } + + state_ = ALLOCATED; + } + + OnSchedule(this); + + return Status::Success; +} + +void +RateLimiter::ModelInstanceContext::Release() +{ + exec_count_++; + + OnRelease_(this); + + { + std::lock_guard lk(state_mtx_); + if ((model_context_->isRemovalInProgress()) && (state_ == 
AVAILABLE) && + (!model_context_->ContainsPendingRequests(index_))) { + state_ = REMOVED; + } + } + + if (state_ == REMOVED) { + cv_.notify_all(); + } +} + +void +RateLimiter::ModelInstanceContext::RequestRemoval() +{ + std::lock_guard lk(state_mtx_); + + if ((state_ == AVAILABLE) && + (!model_context_->ContainsPendingRequests(index_))) { + state_ = REMOVED; + } +} + +void +RateLimiter::ModelInstanceContext::WaitForRemoval() +{ + if (!model_context_->isRemovalInProgress()) { + model_context_->RequestRemoval(); + } + + RequestRemoval(); + + // Wait for the instance to be removed + { + std::unique_lock lk(state_mtx_); + cv_.wait(lk, [this] { return state_ == REMOVED; }); + } +} + +double +RateLimiter::ModelInstanceContext::ScaledPriority() +{ + // TODO: Different schemes for the prioritization of + // model instance can be added here. + // The priority of instance is 1 by default. If specified + // as 0, the priority is still treated as 1. + auto priority = std::max(rate_limiter_config_.priority(), 1u); + return (exec_count_ * priority); +} + + +//========================================================================= +// ResourceManager Implementation +//========================================================================= + +Status +RateLimiter::ResourceManager::Create( + const ResourceMap& resource_map, + std::unique_ptr* resource_manager) +{ + std::unique_ptr local_resource_manager( + new ResourceManager(resource_map)); + *resource_manager = std::move(local_resource_manager); + return Status::Success; +} + +void +RateLimiter::ResourceManager::AddModelInstance( + const ModelInstanceContext* instance) +{ + std::lock_guard lk(model_resources_mtx_); + auto pr = model_resources_.emplace(std::make_pair(instance, ResourceMap())); + for (const auto& resource : instance->GetRateLimiterConfig()->resources()) { + if (resource.global()) { + (pr.first->second[GLOBAL_RESOURCE_KEY])[resource.name()] = + resource.count(); + } else { + (pr.first->second[instance->RawInstance()->DeviceId()])[resource.name()] = + resource.count(); + } + } +} + +Status +RateLimiter::ResourceManager::RemoveModelInstance( + const ModelInstanceContext* instance) +{ + std::lock_guard lk(model_resources_mtx_); + const auto& itr = model_resources_.find(instance); + if (itr == model_resources_.end()) { + return Status( + Status::Code::INTERNAL, "Can not find the instance to remove"); + } + model_resources_.erase(instance); + return Status::Success; +} + +Status +RateLimiter::ResourceManager::UpdateResourceLimits() +{ + std::lock_guard lk1(max_resources_mtx_); + std::lock_guard lk2(model_resources_mtx_); + max_resources_.clear(); + // Obtain the maximum resource across all the instances + // and use it as the default available. 
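+  // For example, if one instance declares resource "R1" with count 4 on
+  // device 0 and another declares "R1" with count 6 on device 0, the default
+  // available count for "R1" on device 0 becomes 6 (the per-device maximum),
+  // unless an explicit limit below overrides it. ("R1" is only an
+  // illustrative resource name.)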
+ for (const auto& instance_resources : model_resources_) { + for (const auto& resource_device_map : instance_resources.second) { + auto ditr = max_resources_.find(resource_device_map.first); + if (ditr == max_resources_.end()) { + ditr = + max_resources_ + .emplace(resource_device_map.first, resource_device_map.second) + .first; + } else { + for (const auto resource : resource_device_map.second) { + auto ritr = ditr->second.find(resource.first); + if (ritr == ditr->second.end()) { + ritr = ditr->second.emplace(resource.first, resource.second).first; + } else { + if (ritr->second < resource.second) { + ritr->second = resource.second; + } + } + } + } + } + } + if (!explicit_max_resources_.empty()) { + RETURN_IF_ERROR(ParseAndValidateExplicitResources()); + } + RETURN_IF_ERROR(ValidateMaxResources()); + + if (LOG_VERBOSE_IS_ON(1)) { + std::string resource_map_str{"\nMax Resource Map===>\n"}; + for (const auto& ditr : max_resources_) { + if (!ditr.second.empty()) { + std::string device_str{(ditr.first == GLOBAL_RESOURCE_KEY) + ? "GLOBAL" + : std::to_string(ditr.first)}; + resource_map_str += "\tDevice: " + device_str + "\n"; + for (const auto& ritr : ditr.second) { + resource_map_str += "\t\tResource: " + ritr.first + + "\t Count: " + std::to_string(ritr.second) + "\n"; + } + } + } + LOG_VERBOSE(1) << resource_map_str; + } + + return Status::Success; +} + +Status +RateLimiter::ResourceManager::ValidateMaxResources() +{ + for (const auto& global_resource : max_resources_[GLOBAL_RESOURCE_KEY]) { + for (const auto& ditr : max_resources_) { + if (ditr.first != GLOBAL_RESOURCE_KEY) { + for (const auto& ritr : ditr.second) { + if (global_resource.first.compare(ritr.first) == 0) { + return Status( + Status::Code::INVALID_ARG, + (std::string("Resource \"") + ritr.first + + "\" is present as both global and device-specific resource in " + "the model configuration.") + .c_str()); + } + } + } + } + } + return Status::Success; +} + +Status +RateLimiter::ResourceManager::ParseAndValidateExplicitResources() +{ + for (auto& ditr : max_resources_) { + for (auto& ritr : ditr.second) { + // If not specified explicitly, consider the resource to be unavailable. + size_t resource_count = 0; + if (ditr.first == GLOBAL_RESOURCE_KEY) { + // Ignore the device specification... will search for all resources in + // the map... + for (const auto& exp_ditr : explicit_max_resources_) { + for (const auto& exp_ritr : exp_ditr.second) { + if (ritr.first.compare(exp_ritr.first) == 0) { + if (resource_count < exp_ritr.second) { + resource_count = exp_ritr.second; + } + } + } + } + } else { + // Search only for the device specific or per-device resources... 
+ // device-specific + for (const auto& exp_ritr : explicit_max_resources_[ditr.first]) { + if (ritr.first.compare(exp_ritr.first) == 0) { + if (resource_count < exp_ritr.second) { + resource_count = exp_ritr.second; + } + } + } + // per-device + for (const auto& exp_ritr : + explicit_max_resources_[PER_DEVICE_RESOURCE_KEY]) { + if (ritr.first.compare(exp_ritr.first) == 0) { + if (resource_count < exp_ritr.second) { + resource_count = exp_ritr.second; + } + } + } + } + if (resource_count < ritr.second) { + return Status( + Status::Code::INVALID_ARG, + (std::string("Resource count for \"") + ritr.first + + "\" is limited to " + std::to_string(resource_count) + + " which will prevent scheduling of one or more model " + "instances, the minimum required count is " + + std::to_string(ritr.second)) + .c_str()); + } else { + ritr.second = resource_count; + } + } + } + + return Status::Success; +} + +bool +RateLimiter::ResourceManager::AllocateResources( + const ModelInstanceContext* instance) +{ + std::lock_guard lk1(model_resources_mtx_); + std::lock_guard lk2(allocated_resources_mtx_); + const auto& itr = model_resources_.find(instance); + if (itr == model_resources_.end()) { + return false; + } else { + // First pass to verify if resources are available + { + std::lock_guard lk3(max_resources_mtx_); + for (const auto& ditr : itr->second) { + auto allocated_ditr = allocated_resources_.find(ditr.first); + if (allocated_ditr == allocated_resources_.end()) { + allocated_ditr = + allocated_resources_ + .emplace(ditr.first, std::map()) + .first; + } + for (const auto& ritr : ditr.second) { + auto allocated_ritr = allocated_ditr->second.find(ritr.first); + if (allocated_ritr == allocated_ditr->second.end()) { + allocated_ritr = + allocated_ditr->second.emplace(ritr.first, 0).first; + } + if ((allocated_ritr->second + ritr.second) > + (max_resources_[ditr.first])[ritr.first]) { + return false; + } + } + } + } + + // Second pass to actually allocate the resources + for (const auto& ditr : itr->second) { + for (const auto& ritr : ditr.second) { + (allocated_resources_[ditr.first])[ritr.first] += ritr.second; + } + } + } + + return true; +} + +Status +RateLimiter::ResourceManager::ReleaseResources( + const ModelInstanceContext* instance) +{ + std::lock_guard lk1(model_resources_mtx_); + std::lock_guard lk2(allocated_resources_mtx_); + const auto& itr = model_resources_.find(instance); + if (itr == model_resources_.end()) { + return Status( + Status::Code::INTERNAL, + "Unable find the instance resources to release"); + } else { + for (const auto& ditr : itr->second) { + for (const auto& ritr : ditr.second) { + (allocated_resources_[ditr.first])[ritr.first] -= ritr.second; + } + } + } + + return Status::Success; +} + +RateLimiter::ResourceManager::ResourceManager(const ResourceMap& resource_map) + : explicit_max_resources_(resource_map) +{ +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/rate_limiter.h b/3rdparty/core-r22.12/src/rate_limiter.h new file mode 100644 index 0000000000000000000000000000000000000000..3734e9bd1224e59b846cdd53d1f861a87c0d6f95 --- /dev/null +++ b/3rdparty/core-r22.12/src/rate_limiter.h @@ -0,0 +1,310 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include + +#include "backend_model.h" +#include "backend_model_instance.h" +#include "instance_queue.h" +#include "model_config.pb.h" +#include "payload.h" +#include "status.h" + +namespace triton { namespace core { + +// Limits the rate at which requests are dispatched to the model instances +class RateLimiter { + public: + using RateLimiterConfig = inference::ModelRateLimiter; + using ResourceMap = std::map>; + enum RESOURCE_KIND_KEY { + // Key for holding global resources + GLOBAL_RESOURCE_KEY = -2, + // Key for holding resources per each device + PER_DEVICE_RESOURCE_KEY = -1 + }; + + /// Creates a rate limiter object which will funnel the requests to + /// the model instances. A typical lifetime of the model instance within + /// RateLimiter transition from available -> staged -> allocated -> available. + /// The transition from available to staged occurs when a request is + /// registered for the model. Depending upon the resource availabilty and + /// priority, the RateLimiter will transition an instance to allocated state + /// at some point in the future. The staged state is skipped when + /// configured to ignore the resource constraints. The cycle in this case + /// will be available -> allocated -> available. + /// \param ignore_resources_and_priority Whether or not to ignore resource + /// constraints and cross-model priority. An available instance is directly + /// allocated when true. + /// \param resource_map The map to the available resource count provided + /// explicitly. + /// \return Status object indicating success or failure. + static Status Create( + const bool ignore_resources_and_priority, const ResourceMap& resource_map, + std::unique_ptr* rate_limiter); + + /// Registers the model instance with the rate limiter. + /// \param instance The pointer to the TritonModelInstance object to register + /// with the rate limiter. + /// \param rate_limiter_config The rate limiter configuration associated with + /// the model instance. 
+ /// \return Status object indicating success or failure. + Status RegisterModelInstance( + TritonModelInstance* instance, + const RateLimiterConfig& rate_limiter_config); + + /// Remove model from the set of models being managed by the rate limiter. + /// \param model The pointer to TritonModel object to be removed. + /// \return Status object indicating success or failure. + Status UnregisterModel(const TritonModel* model); + + /// Returns true if there is a payload slot available for the given model. + /// \param model The pointer to TritonModel object to be removed. + /// \return slot availability in boolean. + bool PayloadSlotAvailable(const TritonModel* model); + + /// Enqueues the payload to rate limiter for scheduling on the given model. + /// \param model The pointer to TritonModel object to be removed. + /// \param payload The shared pointer to the payload object. + /// \return Status object indicating success or failure. + Status EnqueuePayload( + const TritonModel* model, std::shared_ptr payload); + + /// Returns the payload that has been scheduled for the given set of model + /// instances. Note that this call is blocking and depends upon the + /// availability of payloads in the rate limiter for the triton model + /// instance. + /// \param instance The pointers to TritonModelInstance objects whose + /// payload is being requested. + /// \param payload The shared pointer to the payload object. + void DequeuePayload( + std::deque& instance, + std::shared_ptr* payload); + + /// Returns a new payload object. + /// \param op_type The operation type for the payload. + /// \param instance Optional field that providess the model instance that must + /// be used for the execution of the payload. Default is nullptr which allows + /// any model instance to execute the payload. + /// \return The shared pointer to a new payload object. + std::shared_ptr GetPayload( + const Payload::Operation op_type, + TritonModelInstance* instance = nullptr); + + /// Releases the given payload object back to the rate limiter. + /// \param payload The payload to release. + void PayloadRelease(std::shared_ptr& payload); + + private: + class ModelInstanceContext; + class ModelContext; + struct PayloadQueue; + using StandardReleaseFunc = std::function; + using StandardScheduleFunc = std::function; + using StandardStageFunc = std::function; + + // Holds the state of the model instance. 
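+  // As described for Create() above, a registered instance normally cycles
+  // AVAILABLE -> STAGED -> ALLOCATED -> AVAILABLE via Stage(), Allocate()
+  // and MarkAvailable(); when resources and priority are ignored the STAGED
+  // step is skipped through DirectAllocate(). REMOVED is the terminal state
+  // reached once the owning model is being unregistered and no pending
+  // requests remain for the instance.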
+ class ModelInstanceContext { + public: + friend class RateLimiter; + friend class ResourceManager; + enum State { AVAILABLE, STAGED, ALLOCATED, REMOVED }; + + void Release(); + TritonModelInstance* RawInstance() const { return triton_model_instance_; } + + private: + ModelInstanceContext( + TritonModelInstance* triton_model_instance, ModelContext* model_context, + const RateLimiterConfig& rate_limiter_config, StandardStageFunc OnStage, + StandardReleaseFunc OnRelease); + + const RateLimiterConfig* GetRateLimiterConfig() const + { + return &rate_limiter_config_; + } + void MarkAvailable(); + double ScaledPriority(); + Status Stage(StandardScheduleFunc OnSchedule); + Status Allocate(); + Status DirectAllocate(StandardScheduleFunc OnSchedule); + void RequestRemoval(); + void WaitForRemoval(); + + TritonModelInstance* triton_model_instance_; + size_t index_; + ModelContext* model_context_; + RateLimiterConfig rate_limiter_config_; + StandardStageFunc OnStage_; + StandardReleaseFunc OnRelease_; + std::atomic exec_count_; + + State state_; + bool removal_in_progress_; + std::mutex state_mtx_; + + StandardScheduleFunc OnSchedule_; + + std::condition_variable cv_; + }; + + class ScaledPriorityComparator { + public: + bool operator()(ModelInstanceContext* a, ModelInstanceContext* b) + { + return a->ScaledPriority() > b->ScaledPriority(); + } + }; + + using PriorityQueue = std::priority_queue< + ModelInstanceContext*, std::vector, + ScaledPriorityComparator>; + + // Holds the active context to a model + class ModelContext { + public: + ModelContext(); + + Status EnqueueModelInstanceRequest( + const StandardScheduleFunc& OnSchedule, + TritonModelInstance* triton_model_instance); + void AddAvailableInstance(ModelInstanceContext* instance); + void StageInstanceIfAvailable(TritonModelInstance* triton_model_instance); + void AllocateInstanceIfAvailable(); + void AddSpecificRequestQueue(); + bool ContainsPendingRequests(int32_t index); + void RequestRemoval(); + bool isRemovalInProgress() { return removal_in_progress_; } + + private: + bool removal_in_progress_; + + // Queue holding pending scheduling request + std::queue generic_sched_request_queue_; + std::vector> + specific_sched_request_queues_; + std::recursive_mutex sched_request_queue_mtx_; + + // The set of instances that are available at the moment + PriorityQueue avbl_instances_; + std::recursive_mutex avbl_instances_mtx_; + }; + + // Manages and keep track of resource allocation to the model instances. 
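+  // Resources are tracked per device in a ResourceMap keyed by device id,
+  // with GLOBAL_RESOURCE_KEY holding global resources and
+  // PER_DEVICE_RESOURCE_KEY holding counts that apply to each device; every
+  // entry maps a resource name to its count. A purely illustrative map could
+  // hold {GLOBAL_RESOURCE_KEY: {"R1": 10}, 0: {"R2": 4}}.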
+ class ResourceManager { + public: + static Status Create( + const ResourceMap& resource_map, + std::unique_ptr* resource_manager); + void AddModelInstance(const ModelInstanceContext* instance); + Status RemoveModelInstance(const ModelInstanceContext* instance); + Status UpdateResourceLimits(); + bool AllocateResources(const ModelInstanceContext* instance); + Status ReleaseResources(const ModelInstanceContext* instance); + + private: + ResourceManager(const ResourceMap& resource_map); + Status ValidateMaxResources(); + Status ParseAndValidateExplicitResources(); + + ResourceMap explicit_max_resources_; + + std::map model_resources_; + std::mutex model_resources_mtx_; + + ResourceMap max_resources_; + std::mutex max_resources_mtx_; + + ResourceMap allocated_resources_; + std::mutex allocated_resources_mtx_; + }; + + RateLimiter( + const bool ignore_resources_and_priority, + const ResourceMap& resource_map); + + void InitializePayloadQueues(const TritonModelInstance* instance); + Status DeferPayloadSchedule( + const StandardScheduleFunc& OnSchedule, const TritonModel* model, + TritonModelInstance* instance = nullptr); + void OnStage(ModelInstanceContext* instance_ptr); + void OnRelease(ModelInstanceContext* instance_ptr); + void AttemptAllocation(); + void SchedulePayload( + TritonModelInstance* tmi, PayloadQueue* payload_queue, + const std::shared_ptr& payload); + + bool ignore_resources_and_priority_; + + // Instance context for the models + std::map< + const TritonModel*, std::vector>> + model_instance_ctxs_; + std::mutex model_instance_ctx_mtx_; + + // Running context of the models + std::map model_contexts_; + std::mutex model_ctx_mtx_; + + // Holds the model instances that have been staged + PriorityQueue staged_instances_; + std::recursive_mutex staged_instances_mtx_; + + // Manager to keep track of the resource allocations + std::unique_ptr resource_manager_; + + // Mutex to serialize Payload [de]allocation + std::mutex payload_mu_; + + // Mutex to serialize Payload Queues deallocation + std::mutex payload_queues_mu_; + + // Keep some number of Payload objects for reuse to avoid the overhead + // of creating a Payload for every new request. + const size_t max_payload_bucket_count_; + std::vector> payload_bucket_; + std::deque> payloads_in_use_; + + struct PayloadQueue { + explicit PayloadQueue(size_t max_batch_size, uint64_t max_queue_delay_ns) + { + queue_.reset(new InstanceQueue(max_batch_size, max_queue_delay_ns)); + } + std::unique_ptr queue_; + std::map> + specific_queues_; + std::mutex mu_; + std::condition_variable cv_; + }; + std::map> payload_queues_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/repo_agent.cc b/3rdparty/core-r22.12/src/repo_agent.cc new file mode 100644 index 0000000000000000000000000000000000000000..c5c27e6aa3cf483b799a1332e5979dde545370bd --- /dev/null +++ b/3rdparty/core-r22.12/src/repo_agent.cc @@ -0,0 +1,573 @@ +// Copyright 2021-2022, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. 
+// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "repo_agent.h" + +#include +#include "filesystem.h" +#include "shared_library.h" +#include "triton/common/logging.h" +#include "tritonserver_apis.h" + +// For unknown reason, windows will not export the TRITONREPOAGENT_* +// functions declared with dllexport in tritonrepoagent.h. To get +// those functions exported it is (also?) necessary to mark the +// definitions in this file with dllexport as well. +#if defined(_MSC_VER) +#define TRITONAPI_DECLSPEC __declspec(dllexport) +#elif defined(__GNUC__) +#define TRITONAPI_DECLSPEC __attribute__((__visibility__("default"))) +#else +#define TRITONAPI_DECLSPEC +#endif + +namespace triton { namespace core { + +std::string +TritonRepoAgentLibraryName(const std::string& agent_name) +{ +#ifdef _WIN32 + return std::string("tritonrepoagent_") + agent_name + ".dll"; +#else + return std::string("libtritonrepoagent_") + agent_name + ".so"; +#endif +} + +std::string +TRITONREPOAGENT_ActionTypeString(const TRITONREPOAGENT_ActionType type) +{ + switch (type) { + case TRITONREPOAGENT_ACTION_LOAD: + return "TRITONREPOAGENT_ACTION_LOAD"; + case TRITONREPOAGENT_ACTION_LOAD_COMPLETE: + return "TRITONREPOAGENT_ACTION_LOAD_COMPLETE"; + case TRITONREPOAGENT_ACTION_LOAD_FAIL: + return "TRITONREPOAGENT_ACTION_LOAD_FAIL"; + case TRITONREPOAGENT_ACTION_UNLOAD: + return "TRITONREPOAGENT_ACTION_UNLOAD"; + case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: + return "TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE"; + } + return "Unknown TRITONREPOAGENT_ActionType"; +} + +std::string +TRITONREPOAGENT_ArtifactTypeString(const TRITONREPOAGENT_ArtifactType type) +{ + switch (type) { + case TRITONREPOAGENT_ARTIFACT_FILESYSTEM: + return "TRITONREPOAGENT_ARTIFACT_FILESYSTEM"; + case TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM: + return "TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM"; + } + return "Unknown TRITONREPOAGENT_ArtifactType"; +} + +// +// TritonRepoAgent +// +Status +TritonRepoAgent::Create( + const std::string& name, const std::string& libpath, + std::shared_ptr* agent) +{ + std::shared_ptr lagent(new TritonRepoAgent(name)); + + { + std::unique_ptr slib; + RETURN_IF_ERROR(SharedLibrary::Acquire(&slib)); + + RETURN_IF_ERROR(slib->OpenLibraryHandle(libpath, &lagent->dlhandle_)); + RETURN_IF_ERROR(slib->GetEntrypoint( + lagent->dlhandle_, "TRITONREPOAGENT_Initialize", true /* optional */, + reinterpret_cast(&lagent->init_fn_))); + RETURN_IF_ERROR(slib->GetEntrypoint( + lagent->dlhandle_, "TRITONREPOAGENT_Finalize", true /* optional */, + reinterpret_cast(&lagent->fini_fn_))); + 
RETURN_IF_ERROR(slib->GetEntrypoint( + lagent->dlhandle_, "TRITONREPOAGENT_ModelInitialize", + true /* optional */, + reinterpret_cast(&lagent->model_init_fn_))); + RETURN_IF_ERROR(slib->GetEntrypoint( + lagent->dlhandle_, "TRITONREPOAGENT_ModelFinalize", true /* optional */, + reinterpret_cast(&lagent->model_fini_fn_))); + RETURN_IF_ERROR(slib->GetEntrypoint( + lagent->dlhandle_, "TRITONREPOAGENT_ModelAction", false /* optional */, + reinterpret_cast(&lagent->model_action_fn_))); + } + + // Initialize if needed + if (lagent->init_fn_ != nullptr) { + RETURN_IF_TRITONSERVER_ERROR(lagent->init_fn_( + reinterpret_cast(lagent.get()))); + } + + *agent = std::move(lagent); + return Status::Success; +} + +TritonRepoAgent::~TritonRepoAgent() +{ + // Finalize if needed + if (fini_fn_ != nullptr) { + auto err = fini_fn_(reinterpret_cast(this)); + if (err != nullptr) { + LOG_ERROR << "~TritonRepoAgent: " + << Status( + TritonCodeToStatusCode(TRITONSERVER_ErrorCode(err)), + TRITONSERVER_ErrorMessage(err)) + .AsString(); + TRITONSERVER_ErrorDelete(err); + }; + } + + { + std::unique_ptr slib; + LOG_STATUS_ERROR(SharedLibrary::Acquire(&slib), "~TritonRepoAgent"); + LOG_STATUS_ERROR(slib->CloseLibraryHandle(dlhandle_), "~TritonRepoAgent"); + } +} + +// +// TritonRepoAgentModel +// +Status +TritonRepoAgentModel::Create( + const TRITONREPOAGENT_ArtifactType type, const std::string& location, + const inference::ModelConfig& config, + const std::shared_ptr& agent, + const TritonRepoAgent::Parameters& agent_parameters, + std::unique_ptr* agent_model) +{ + std::unique_ptr lagent_model(new TritonRepoAgentModel( + type, location, config, agent, agent_parameters)); + if (agent->AgentModelInitFn() != nullptr) { + RETURN_IF_TRITONSERVER_ERROR(agent->AgentModelInitFn()( + reinterpret_cast(agent.get()), + reinterpret_cast(lagent_model.get()))); + } + *agent_model = std::move(lagent_model); + return Status::Success; +} + +TritonRepoAgentModel::~TritonRepoAgentModel() +{ + // Need to ensure the proper lifecycle is informed + if (action_type_set_) { + switch (current_action_type_) { + case TRITONREPOAGENT_ACTION_LOAD: + LOG_TRITONSERVER_ERROR( + agent_->AgentModelActionFn()( + reinterpret_cast(agent_.get()), + reinterpret_cast(this), + TRITONREPOAGENT_ACTION_LOAD_FAIL), + "Inform TRITONREPOAGENT_ACTION_LOAD_FAIL"); + break; + case TRITONREPOAGENT_ACTION_LOAD_COMPLETE: + LOG_TRITONSERVER_ERROR( + agent_->AgentModelActionFn()( + reinterpret_cast(agent_.get()), + reinterpret_cast(this), + TRITONREPOAGENT_ACTION_UNLOAD), + "Inform TRITONREPOAGENT_ACTION_UNLOAD"); + // Fallthough is not yet an language feature until C++17 + LOG_TRITONSERVER_ERROR( + agent_->AgentModelActionFn()( + reinterpret_cast(agent_.get()), + reinterpret_cast(this), + TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE), + "Inform TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE"); + break; + case TRITONREPOAGENT_ACTION_UNLOAD: + LOG_TRITONSERVER_ERROR( + agent_->AgentModelActionFn()( + reinterpret_cast(agent_.get()), + reinterpret_cast(this), + TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE), + "Inform TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE"); + break; + case TRITONREPOAGENT_ACTION_LOAD_FAIL: + case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: + break; + } + } + if (agent_->AgentModelFiniFn() != nullptr) { + LOG_TRITONSERVER_ERROR( + agent_->AgentModelFiniFn()( + reinterpret_cast(agent_.get()), + reinterpret_cast(this)), + "~TritonRepoAgentModel"); + } + if (!acquired_location_.empty()) { + DeleteMutableLocation(); + } +} + +Status +TritonRepoAgentModel::InvokeAgent(const 
TRITONREPOAGENT_ActionType action_type) +{ + if ((!action_type_set_) && (action_type != TRITONREPOAGENT_ACTION_LOAD)) { + return Status( + Status::Code::INTERNAL, + "Unexpected lifecycle start state " + + TRITONREPOAGENT_ActionTypeString(action_type)); + } + switch (action_type) { + case TRITONREPOAGENT_ACTION_LOAD: + if (action_type_set_) { + return Status( + Status::Code::INTERNAL, + "Unexpected lifecycle state transition from " + + TRITONREPOAGENT_ActionTypeString(current_action_type_) + + " to " + TRITONREPOAGENT_ActionTypeString(action_type)); + } + break; + case TRITONREPOAGENT_ACTION_LOAD_COMPLETE: + case TRITONREPOAGENT_ACTION_LOAD_FAIL: + if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD) { + return Status( + Status::Code::INTERNAL, + "Unexpected lifecycle state transition from " + + TRITONREPOAGENT_ActionTypeString(current_action_type_) + + " to " + TRITONREPOAGENT_ActionTypeString(action_type)); + } + break; + case TRITONREPOAGENT_ACTION_UNLOAD: + if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD_COMPLETE) { + return Status( + Status::Code::INTERNAL, + "Unexpected lifecycle state transition from " + + TRITONREPOAGENT_ActionTypeString(current_action_type_) + + " to " + TRITONREPOAGENT_ActionTypeString(action_type)); + } + break; + case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: + if (current_action_type_ != TRITONREPOAGENT_ACTION_UNLOAD) { + return Status( + Status::Code::INTERNAL, + "Unexpected lifecycle state transition from " + + TRITONREPOAGENT_ActionTypeString(current_action_type_) + + " to " + TRITONREPOAGENT_ActionTypeString(action_type)); + } + break; + } + current_action_type_ = action_type; + action_type_set_ = true; + RETURN_IF_TRITONSERVER_ERROR(agent_->AgentModelActionFn()( + reinterpret_cast(agent_.get()), + reinterpret_cast(this), action_type)); + return Status::Success; +} + +Status +TritonRepoAgentModel::SetLocation( + const TRITONREPOAGENT_ArtifactType type, const std::string& location) +{ + if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD) { + return Status( + Status::Code::INVALID_ARG, + "location can only be updated during TRITONREPOAGENT_ACTION_LOAD, " + "current action type is " + + (action_type_set_ + ? 
TRITONREPOAGENT_ActionTypeString(current_action_type_) + : "not set")); + } + type_ = type; + location_ = location; + return Status::Success; +} + +Status +TritonRepoAgentModel::Location( + TRITONREPOAGENT_ArtifactType* type, const char** location) +{ + if (location_.empty()) { + return Status( + Status::Code::INTERNAL, "Model repository location is not set"); + } + *type = type_; + *location = location_.c_str(); + return Status::Success; +} + +Status +TritonRepoAgentModel::AcquireMutableLocation( + const TRITONREPOAGENT_ArtifactType type, const char** location) +{ + if (type != TRITONREPOAGENT_ARTIFACT_FILESYSTEM) { + return Status( + Status::Code::INVALID_ARG, + "Unexpected artifact type, expects " + "'TRITONREPOAGENT_ARTIFACT_FILESYSTEM'"); + } + if (acquired_location_.empty()) { + std::string lacquired_location; + RETURN_IF_ERROR( + MakeTemporaryDirectory(FileSystemType::LOCAL, &lacquired_location)); + acquired_location_.swap(lacquired_location); + acquired_type_ = type; + } + *location = acquired_location_.c_str(); + return Status::Success; +} + +Status +TritonRepoAgentModel::DeleteMutableLocation() +{ + if (acquired_location_.empty()) { + return Status( + Status::Code::UNAVAILABLE, "No mutable location to be deleted"); + } + + auto status = DeletePath(acquired_location_); + if (!status.IsOk()) { + LOG_ERROR << "Failed to delete previously acquired location '" + << acquired_location_ << "': " << status.AsString(); + } + acquired_location_.clear(); + return Status::Success; +} + +// +// TritonRepoAgentManager +// +TritonRepoAgentManager& +TritonRepoAgentManager::Singleton() +{ + static TritonRepoAgentManager triton_repo_agent_manager; + return triton_repo_agent_manager; +} + +Status +TritonRepoAgentManager::SetGlobalSearchPath(const std::string& path) +{ + auto& singleton_manager = Singleton(); + std::lock_guard lock(singleton_manager.mu_); + singleton_manager.global_search_path_ = path; + return Status::Success; +} + +Status +TritonRepoAgentManager::CreateAgent( + const std::string& agent_name, std::shared_ptr* agent) +{ + auto& singleton_manager = Singleton(); + std::lock_guard lock(singleton_manager.mu_); + + // Get the path to the agent shared library. Search path is global + // agent directory. FIXME expose global path as Triton option + const std::vector search_paths = { + JoinPath({singleton_manager.global_search_path_, agent_name})}; + + std::string agent_libname = TritonRepoAgentLibraryName(agent_name); + std::string libpath; + for (const auto& path : search_paths) { + const auto full_path = JoinPath({path, agent_libname}); + bool exists = false; + RETURN_IF_ERROR(FileExists(full_path, &exists)); + if (exists) { + libpath = full_path; + break; + } + } + + if (libpath.empty()) { + return Status( + Status::Code::INVALID_ARG, + "unable to find '" + agent_libname + "' for repo agent '" + agent_name + + "', searched: " + singleton_manager.global_search_path_); + } + + const auto& itr = singleton_manager.agent_map_.find(libpath); + if (itr != singleton_manager.agent_map_.end()) { + // Found in map. If the weak_ptr is still valid that means that + // there are other models using the agent and we just reuse that + // same agent. If the weak_ptr is not valid then agent has been + // unloaded so we need to remove the weak_ptr from the map and + // create the agent again. 
+ *agent = itr->second.lock(); + if (*agent != nullptr) { + return Status::Success; + } + + singleton_manager.agent_map_.erase(itr); + } + RETURN_IF_ERROR(TritonRepoAgent::Create(agent_name, libpath, agent)); + singleton_manager.agent_map_.insert({libpath, *agent}); + + return Status::Success; +} + +Status +TritonRepoAgentManager::AgentState( + std::unique_ptr>* agent_state) +{ + auto& singleton_manager = Singleton(); + std::lock_guard lock(singleton_manager.mu_); + + std::unique_ptr> agent_state_map( + new std::unordered_map); + for (const auto& agent_pair : singleton_manager.agent_map_) { + auto& libpath = agent_pair.first; + auto agent = agent_pair.second.lock(); + + if (agent != nullptr) { + agent_state_map->insert({agent->Name(), libpath}); + } + } + + *agent_state = std::move(agent_state_map); + + return Status::Success; +} + +extern "C" { + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_ApiVersion(uint32_t* major, uint32_t* minor) +{ + *major = TRITONREPOAGENT_API_VERSION_MAJOR; + *minor = TRITONREPOAGENT_API_VERSION_MINOR; + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_ModelRepositoryLocation( + TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + TRITONREPOAGENT_ArtifactType* artifact_type, const char** location) +{ + TritonRepoAgentModel* tam = reinterpret_cast(model); + RETURN_TRITONSERVER_ERROR_IF_ERROR(tam->Location(artifact_type, location)); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_ModelRepositoryLocationAcquire( + TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ArtifactType artifact_type, const char** location) +{ + TritonRepoAgentModel* tam = reinterpret_cast(model); + RETURN_TRITONSERVER_ERROR_IF_ERROR( + tam->AcquireMutableLocation(artifact_type, location)); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_ModelRepositoryLocationRelease( + TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const char* location) +{ + TritonRepoAgentModel* tam = reinterpret_cast(model); + RETURN_TRITONSERVER_ERROR_IF_ERROR(tam->DeleteMutableLocation()); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_ModelRepositoryUpdate( + TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ArtifactType artifact_type, const char* location) +{ + TritonRepoAgentModel* tam = reinterpret_cast(model); + RETURN_TRITONSERVER_ERROR_IF_ERROR(tam->SetLocation(artifact_type, location)); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_ModelParameterCount( + TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + uint32_t* count) +{ + TritonRepoAgentModel* tam = reinterpret_cast(model); + *count = tam->AgentParameters().size(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_ModelParameter( + TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const uint32_t index, const char** parameter_name, + const char** parameter_value) +{ + TritonRepoAgentModel* tam = reinterpret_cast(model); + const auto& params = tam->AgentParameters(); + if (index >= params.size()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "index out of range for model parameters"); + } + *parameter_name = params[index].first.c_str(); + *parameter_value = params[index].second.c_str(); + return nullptr; // success +} + 
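+// The functions in this extern "C" block are the server-side half of the
+// repo agent API: an agent shared library (named
+// libtritonrepoagent_<name>.so, see TritonRepoAgentLibraryName above) calls
+// back into them while handling an action. As a rough, purely illustrative
+// sketch (the body is hypothetical; only the required entry-point signature
+// and the calls used are taken from this file), a minimal agent would
+// export:
+//
+//   extern "C" TRITONSERVER_Error*
+//   TRITONREPOAGENT_ModelAction(
+//       TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
+//       const TRITONREPOAGENT_ActionType action_type)
+//   {
+//     if (action_type == TRITONREPOAGENT_ACTION_LOAD) {
+//       TRITONREPOAGENT_ArtifactType artifact_type;
+//       const char* location;
+//       TRITONSERVER_Error* err = TRITONREPOAGENT_ModelRepositoryLocation(
+//           agent, model, &artifact_type, &location);
+//       if (err != nullptr) {
+//         return err;
+//       }
+//       // Inspect or transform the repository content found at 'location'.
+//     }
+//     return nullptr;  // success
+//   }
+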
+TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_ModelConfig( + TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const uint32_t config_version, TRITONSERVER_Message** model_config) +{ + TritonRepoAgentModel* tam = reinterpret_cast(model); + std::string model_config_json; + RETURN_TRITONSERVER_ERROR_IF_ERROR( + ModelConfigToJson(tam->Config(), config_version, &model_config_json)); + return TRITONSERVER_MessageNewFromSerializedJson( + model_config, model_config_json.c_str(), model_config_json.length()); +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_ModelState(TRITONREPOAGENT_AgentModel* model, void** state) +{ + TritonRepoAgentModel* tam = reinterpret_cast(model); + *state = tam->State(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_ModelSetState(TRITONREPOAGENT_AgentModel* model, void* state) +{ + TritonRepoAgentModel* tam = reinterpret_cast(model); + tam->SetState(state); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_State(TRITONREPOAGENT_Agent* agent, void** state) +{ + TritonRepoAgent* ta = reinterpret_cast(agent); + *state = ta->State(); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONREPOAGENT_SetState(TRITONREPOAGENT_Agent* agent, void* state) +{ + TritonRepoAgent* ta = reinterpret_cast(agent); + ta->SetState(state); + return nullptr; // success +} + +} // extern C + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/repo_agent.h b/3rdparty/core-r22.12/src/repo_agent.h new file mode 100644 index 0000000000000000000000000000000000000000..001b6f7406cb083a4e481c7e03c0823bf98a8721 --- /dev/null +++ b/3rdparty/core-r22.12/src/repo_agent.h @@ -0,0 +1,182 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+#pragma once + +#include "tritonserver_apis.h" + +#include +#include +#include +#include +#include "constants.h" +#include "model_config_utils.h" + +namespace triton { namespace core { + +std::string TritonRepoAgentLibraryName(const std::string& agent_name); + +std::string TRITONREPOAGENT_ActionTypeString( + const TRITONREPOAGENT_ActionType type); + +std::string TRITONREPOAGENT_ArtifactTypeString( + const TRITONREPOAGENT_ArtifactType type); + +class TritonRepoAgent { + public: + using Parameters = std::vector>; + typedef TRITONSERVER_Error* (*TritonRepoAgentInitFn_t)( + TRITONREPOAGENT_Agent* agent); + typedef TRITONSERVER_Error* (*TritonRepoAgentFiniFn_t)( + TRITONREPOAGENT_Agent* agent); + typedef TRITONSERVER_Error* (*TritonRepoAgentModelInitFn_t)( + TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model); + typedef TRITONSERVER_Error* (*TritonRepoAgentModelFiniFn_t)( + TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model); + typedef TRITONSERVER_Error* (*TritonRepoAgentModelActionFn_t)( + TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type); + + static Status Create( + const std::string& name, const std::string& libpath, + std::shared_ptr* agent); + ~TritonRepoAgent(); + + const std::string& Name() { return name_; } + void* State() { return state_; } + void SetState(void* state) { state_ = state; } + + TritonRepoAgentModelActionFn_t AgentModelActionFn() const + { + return model_action_fn_; + } + + TritonRepoAgentModelInitFn_t AgentModelInitFn() const + { + return model_init_fn_; + } + + TritonRepoAgentModelFiniFn_t AgentModelFiniFn() const + { + return model_fini_fn_; + } + + protected: + DISALLOW_COPY_AND_ASSIGN(TritonRepoAgent); + + TritonRepoAgent(const std::string& name) + : name_(name), state_(nullptr), dlhandle_(nullptr), init_fn_(nullptr), + fini_fn_(nullptr), model_init_fn_(nullptr), model_fini_fn_(nullptr), + model_action_fn_(nullptr) + { + } + const std::string name_; + void* state_; + + // dlopen / dlsym handles + void* dlhandle_; + TritonRepoAgentInitFn_t init_fn_; + TritonRepoAgentFiniFn_t fini_fn_; + TritonRepoAgentModelInitFn_t model_init_fn_; + TritonRepoAgentModelFiniFn_t model_fini_fn_; + TritonRepoAgentModelActionFn_t model_action_fn_; +}; + +class TritonRepoAgentModel { + public: + static Status Create( + const TRITONREPOAGENT_ArtifactType type, const std::string& location, + const inference::ModelConfig& config, + const std::shared_ptr& agent, + const TritonRepoAgent::Parameters& agent_parameters, + std::unique_ptr* agent_model); + ~TritonRepoAgentModel(); + + void* State() { return state_; } + void SetState(void* state) { state_ = state; } + + Status InvokeAgent(const TRITONREPOAGENT_ActionType action_type); + const TritonRepoAgent::Parameters& AgentParameters() + { + return agent_parameters_; + } + + Status SetLocation( + const TRITONREPOAGENT_ArtifactType type, const std::string& location); + Status Location(TRITONREPOAGENT_ArtifactType* type, const char** location); + Status AcquireMutableLocation( + const TRITONREPOAGENT_ArtifactType type, const char** location); + Status DeleteMutableLocation(); + const inference::ModelConfig Config() { return config_; } + + private: + DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentModel); + + TritonRepoAgentModel( + const TRITONREPOAGENT_ArtifactType type, const std::string& location, + const inference::ModelConfig& config, + const std::shared_ptr& agent, + const TritonRepoAgent::Parameters& agent_parameters) + : state_(nullptr), config_(config), 
agent_(agent), + agent_parameters_(agent_parameters), type_(type), location_(location), + action_type_set_(false), + current_action_type_(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE) + { + } + + void* state_; + const inference::ModelConfig config_; + const std::shared_ptr agent_; + const TritonRepoAgent::Parameters agent_parameters_; + TRITONREPOAGENT_ArtifactType type_; + std::string location_; + TRITONREPOAGENT_ArtifactType acquired_type_; + std::string acquired_location_; + bool action_type_set_; + TRITONREPOAGENT_ActionType current_action_type_; +}; + +class TritonRepoAgentManager { + public: + static Status SetGlobalSearchPath(const std::string& path); + static Status CreateAgent( + const std::string& agent_name, std::shared_ptr* agent); + + static Status AgentState( + std::unique_ptr>* + agent_state); + + private: + DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentManager); + + TritonRepoAgentManager() + : global_search_path_("/opt/tritonserver/repoagents"){}; + static TritonRepoAgentManager& Singleton(); + std::mutex mu_; + std::string global_search_path_; + std::unordered_map> agent_map_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/response_allocator.h b/3rdparty/core-r22.12/src/response_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..143cc7ff877cc141c1a4110ebf03fd1f2a82227e --- /dev/null +++ b/3rdparty/core-r22.12/src/response_allocator.h @@ -0,0 +1,77 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include "tritonserver_apis.h" + +namespace triton { namespace core { + +// +// Implementation for TRITONSERVER_ResponseAllocator. 
+// +class ResponseAllocator { + public: + explicit ResponseAllocator( + TRITONSERVER_ResponseAllocatorAllocFn_t alloc_fn, + TRITONSERVER_ResponseAllocatorReleaseFn_t release_fn, + TRITONSERVER_ResponseAllocatorStartFn_t start_fn) + : alloc_fn_(alloc_fn), buffer_attributes_fn_(nullptr), query_fn_(nullptr), + release_fn_(release_fn), start_fn_(start_fn) + { + } + + void SetQueryFunction(TRITONSERVER_ResponseAllocatorQueryFn_t query_fn) + { + query_fn_ = query_fn; + } + + void SetBufferAttributesFunction( + TRITONSERVER_ResponseAllocatorBufferAttributesFn_t buffer_attributes_fn) + { + buffer_attributes_fn_ = buffer_attributes_fn; + } + + TRITONSERVER_ResponseAllocatorAllocFn_t AllocFn() const { return alloc_fn_; } + TRITONSERVER_ResponseAllocatorBufferAttributesFn_t BufferAttributesFn() const + { + return buffer_attributes_fn_; + } + TRITONSERVER_ResponseAllocatorQueryFn_t QueryFn() const { return query_fn_; } + TRITONSERVER_ResponseAllocatorReleaseFn_t ReleaseFn() const + { + return release_fn_; + } + TRITONSERVER_ResponseAllocatorStartFn_t StartFn() const { return start_fn_; } + + private: + TRITONSERVER_ResponseAllocatorAllocFn_t alloc_fn_; + TRITONSERVER_ResponseAllocatorBufferAttributesFn_t buffer_attributes_fn_; + TRITONSERVER_ResponseAllocatorQueryFn_t query_fn_; + TRITONSERVER_ResponseAllocatorReleaseFn_t release_fn_; + TRITONSERVER_ResponseAllocatorStartFn_t start_fn_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/response_cache.cc b/3rdparty/core-r22.12/src/response_cache.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff5f0707accec7aaf2302503f968889b9883dba4 --- /dev/null +++ b/3rdparty/core-r22.12/src/response_cache.cc @@ -0,0 +1,542 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
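The `ResponseAllocator` class above is only a container for callbacks that a client of the server C API registers; Triton invokes them when it needs output buffers for a response. Below is a minimal sketch of a CPU-only allocator pair as it might be written against `TRITONSERVER_ResponseAllocatorNew`. The callback signatures are written from memory of `tritonserver.h` and should be verified against the header; the query and buffer-attributes callbacks are optional and omitted, matching the `nullptr` defaults in the wrapper above.

```cpp
// Hedged sketch: CPU-only response allocator callbacks for the Triton C API.
// Signatures and the include path are assumptions to check against tritonserver.h.
#include <cstdlib>

#include "triton/core/tritonserver.h"  // assumed install path for the C API header

static TRITONSERVER_Error*
CpuAlloc(
    TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name,
    size_t byte_size, TRITONSERVER_MemoryType memory_type,
    int64_t memory_type_id, void* userp, void** buffer, void** buffer_userp,
    TRITONSERVER_MemoryType* actual_memory_type, int64_t* actual_memory_type_id)
{
  // Satisfy every request from plain CPU memory, whatever type was preferred.
  *buffer = (byte_size == 0) ? nullptr : std::malloc(byte_size);
  *buffer_userp = nullptr;
  *actual_memory_type = TRITONSERVER_MEMORY_CPU;
  *actual_memory_type_id = 0;
  return nullptr;  // success
}

static TRITONSERVER_Error*
CpuRelease(
    TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp,
    size_t byte_size, TRITONSERVER_MemoryType memory_type,
    int64_t memory_type_id)
{
  std::free(buffer);  // release what CpuAlloc handed out
  return nullptr;     // success
}

// Usage sketch: create the allocator once and reuse it for every request.
// TRITONSERVER_ResponseAllocator* allocator = nullptr;
// TRITONSERVER_ResponseAllocatorNew(
//     &allocator, CpuAlloc, CpuRelease, /* start_fn = */ nullptr);
```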
+ +#include "response_cache.h" +#include "infer_stats.h" +#include "triton/common/logging.h" + +namespace { + +enum class ScopedTimerType { INSERTION, LOOKUP }; + +class ScopedTimer { + public: + explicit ScopedTimer( + triton::core::InferenceRequest& request, uint64_t& duration, + ScopedTimerType type) + : request_(request), duration_(duration), type_(type) + { + switch (type_) { + case ScopedTimerType::LOOKUP: + request_.CaptureCacheLookupStartNs(); + break; + case ScopedTimerType::INSERTION: + request_.CaptureCacheInsertionStartNs(); + break; + } + } + + ~ScopedTimer() + { + switch (type_) { + case ScopedTimerType::LOOKUP: + request_.CaptureCacheLookupEndNs(); + duration_ += + request_.CacheLookupEndNs() - request_.CacheLookupStartNs(); + break; + case ScopedTimerType::INSERTION: + request_.CaptureCacheInsertionEndNs(); + duration_ += + request_.CacheInsertionEndNs() - request_.CacheInsertionStartNs(); + break; + } + } + + private: + triton::core::InferenceRequest& request_; + uint64_t& duration_; + ScopedTimerType type_; +}; + +std::string +PointerToString(void* ptr) +{ + std::stringstream ss; + ss << ptr; + return ss.str(); +} + +} // namespace + +namespace triton { namespace core { + +Status +RequestResponseCache::Create( + uint64_t cache_size, std::unique_ptr* cache) +{ + try { + cache->reset(new RequestResponseCache(cache_size)); + } + catch (const std::exception& ex) { + return Status( + Status::Code::INTERNAL, + "Failed to initialize Response Cache: " + std::string(ex.what())); + } + + return Status::Success; +} + +RequestResponseCache::RequestResponseCache(const uint64_t size) +{ + // Allocate buffer + buffer_ = malloc(size); + // Exit early if buffer allocation failed + if (buffer_ == nullptr) { + throw std::runtime_error("failed to allocate buffer"); + } + + // Create cache as managed buffer + managed_buffer_ = boost::interprocess::managed_external_buffer( + boost::interprocess::create_only_t{}, buffer_, size); + + LOG_INFO << "Response Cache is created at '" << PointerToString(buffer_) + << "' with size " << size; +} + +RequestResponseCache::~RequestResponseCache() +{ + // Deallocate each chunk from managed buffer + for (auto& iter : cache_) { + auto& entry = iter.second; + for (auto& output : entry.outputs_) { + if (output.buffer_ != nullptr) { + managed_buffer_.deallocate(output.buffer_); + } + } + } + + // Validate we freed all underlying memory managed by cache + if (!managed_buffer_.all_memory_deallocated()) { + // Destructors can't throw exceptions + LOG_ERROR << "failed to free managed cache memory"; + } + + // Free total cache buffer + if (buffer_ != nullptr) { + free(buffer_); + } +} + +Status +RequestResponseCache::Lookup( + InferenceResponse* const response, InferenceRequest* const request) +{ + // Lock on cache lookup + std::lock_guard lk(cache_mtx_); + + if (request == nullptr) { + return Status( + Status::Code::INTERNAL, "Cache Lookup passed a nullptr request"); + } + + // Capture start latency now and end latency when timer goes out of scope + ScopedTimer timer( + *request, total_lookup_latency_ns_, ScopedTimerType::LOOKUP); + + // Hash the request and set cache key if it hasn't already been set + if (!request->CacheKeyIsSet()) { + RETURN_IF_ERROR(HashAndSet(request)); + } + const uint64_t key = request->CacheKey(); + + num_lookups_++; + LOG_VERBOSE(1) << request->LogRequest() + << "Looking up key [" + std::to_string(key) + "] in cache."; + + // Search cache for request hash key + auto iter = cache_.find(key); + if (iter == cache_.end()) { + num_misses_++; + 
LOG_VERBOSE(1) << request->LogRequest() + << "MISS for key [" + std::to_string(key) + "] in cache."; + return Status( + Status::Code::INTERNAL, + request->LogRequest() + "key not found in cache"); + } + + // If find succeeds, it's a cache hit + num_hits_++; + LOG_VERBOSE(1) << request->LogRequest() + << "HIT for key [" + std::to_string(key) + "] in cache."; + + // Populate passed-in "response" from cache entry + auto entry = iter->second; + // Build InferenceResponse from CacheEntry + RETURN_IF_ERROR(BuildInferenceResponse(entry, response)); + + // Update this key to front of LRU list + UpdateLRU(iter); + LOG_VERBOSE(1) << request->LogRequest() + << "Using cached response for key [" + std::to_string(key) + + "]."; + return Status::Success; +} + +Status +RequestResponseCache::Insert( + const InferenceResponse& response, InferenceRequest* const request) +{ + // Lock on cache insertion + std::lock_guard lk(cache_mtx_); + + if (request == nullptr) { + return Status( + Status::Code::INTERNAL, "Cache Insert passed a nullptr request"); + } + + // Capture start latency now and end latency when timer goes out of scope + ScopedTimer timer( + *request, total_insertion_latency_ns_, ScopedTimerType::INSERTION); + + // Hash the request and set cache key if it hasn't already been set + if (!request->CacheKeyIsSet()) { + RETURN_IF_ERROR(HashAndSet(request)); + } + const uint64_t key = request->CacheKey(); + + // Exit early if key already exists in cache + auto iter = cache_.find(key); + if (iter != cache_.end()) { + return Status( + Status::Code::ALREADY_EXISTS, request->LogRequest() + "key [" + + std::to_string(key) + + "] already exists in cache"); + } + + // Construct cache entry from response + auto entry = CacheEntry(); + RETURN_IF_ERROR(BuildCacheEntry(response, &entry)); + + // Insert entry into cache + LOG_VERBOSE(1) << request->LogRequest() + << "Inserting key [" + std::to_string(key) + "] into cache."; + auto cache_pair = cache_.insert({key, entry}); + // Exit early if cache insertion failed + if (!cache_pair.second) { + LOG_ERROR << request->LogRequest() << "Failed to insert key into map."; + return Status( + Status::Code::INTERNAL, + request->LogRequest() + "Cache insertion failed"); + } + // Update LRU with new cache entry + auto cache_iter = cache_pair.first; + UpdateLRU(cache_iter); + + return Status::Success; +} + +// LRU +Status +RequestResponseCache::Evict() +{ + // Lock on cache eviction + std::lock_guard lk(cache_mtx_); + + // Nothing to evict if cache is empty + if (NumEntries() == 0) { + return Status(Status::Code::INTERNAL, "Cache is empty, nothing to evict."); + } + + // Least recently used key in back of LRU list + uint64_t lru_key = lru_.back(); + LOG_VERBOSE(1) << "Evicting key [" + std::to_string(lru_key) + + "] from cache."; + + // Find cache entry for least recently used key + auto iter = cache_.find(lru_key); + // Error check if key isn't in cache, but this shouldn't happen in evict + // and probably indicates a bug + if (iter == cache_.end()) { + return Status( + Status::Code::INTERNAL, + "key [" + std::to_string(lru_key) + + "] not found in cache during eviction: this indicates a bug in the " + "code"); + } + // Get size of cache entry being evicted to update available size + auto entry = iter->second; + // Free managed memory used in cache entry's outputs + for (auto& output : entry.outputs_) { + // Lock on buffer deallocation + std::lock_guard lk(buffer_mtx_); + managed_buffer_.deallocate(output.buffer_); + } + + // Remove LRU entry from cache + cache_.erase(lru_key); + 
// Remove LRU key from LRU list + lru_.pop_back(); + // Increment number of evictions + num_evictions_++; + + return Status::Success; +} + +// Helpers +void +RequestResponseCache::UpdateLRU( + std::unordered_map::iterator& cache_iter) +{ + // Lock on cache update + std::lock_guard lk(cache_mtx_); + + const auto& key = cache_iter->first; + auto& cache_entry = cache_iter->second; + // Remove key from LRU list if it was already in there + auto lru_iter = std::find(lru_.begin(), lru_.end(), key); + if (lru_iter != lru_.end()) { + lru_.erase(lru_iter); + } + // Add key to front of LRU list since it's most recently used + lru_.push_front(key); + // Set CacheEntry LRU iterator to new LRU key location + cache_entry.lru_iter_ = lru_.begin(); +} + +Status +RequestResponseCache::BuildCacheEntry( + const InferenceResponse& response, CacheEntry* const entry) +{ + // Build cache entry data from response outputs + for (const auto& response_output : response.Outputs()) { + auto cache_output = Output(); + + // Fetch output buffer details + const void* response_buffer = nullptr; + size_t response_byte_size = 0; + TRITONSERVER_MemoryType response_memory_type; + int64_t response_memory_type_id; + void* userp; + RETURN_IF_ERROR(response_output.DataBuffer( + &response_buffer, &response_byte_size, &response_memory_type, + &response_memory_type_id, &userp)); + + // TODO: Handle other memory types + if (response_memory_type != TRITONSERVER_MEMORY_CPU && + response_memory_type != TRITONSERVER_MEMORY_CPU_PINNED) { + return Status( + Status::Code::INTERNAL, + "Only input buffers in CPU memory are allowed in cache currently"); + } + + // Exit early if response buffer from output is invalid + if (response_buffer == nullptr) { + return Status( + Status::Code::INTERNAL, "Response buffer from output was nullptr"); + } + + // Lock on managed buffer references + { + std::lock_guard lk(buffer_mtx_); + + // Exit early if cache entry will be larger than available cache size + if (response_byte_size > managed_buffer_.get_size()) { + return Status( + Status::Code::INTERNAL, + "Cache entry is larger than total cache size"); + } + + // If cache doesn't have enough space, evict until enough space available + // NOTE: FreeBytes() doesn't account for allocator overhead so allocation + // may fail even if response_byte_size is less than FreeBytes() + while (response_byte_size > FreeBytes()) { + LOG_VERBOSE(1) << "EVICT: Response larger than remaining available " + "memory, attempting to evict from cache."; + RETURN_IF_ERROR(Evict()); + } + + // Attempt to allocate buffer until success or eviction from cache fails + while (cache_output.buffer_ == nullptr) { + // Allocate buffer for response output in cache entry + cache_output.buffer_ = + managed_buffer_.allocate(response_byte_size, std::nothrow_t{}); + // Attempt to evict if allocation fails + if (cache_output.buffer_ == nullptr) { + LOG_VERBOSE(1) << "FAILED to allocate buffer in cache. 
Attempting to " + "evict an entry."; + // Exit out if Eviction fails + RETURN_IF_ERROR(Evict()); + } + } + + // Copy data from response buffer to cache entry output buffer + // TODO: Handle other memory types + std::memcpy(cache_output.buffer_, response_buffer, response_byte_size); + + // Set output metadata + cache_output.name_ = response_output.Name(); + cache_output.dtype_ = response_output.DType(); + cache_output.shape_ = response_output.Shape(); + cache_output.buffer_size_ = static_cast(response_byte_size); + } + + // Add each output to cache entry + entry->outputs_.push_back(cache_output); + } + + return Status::Success; +} + + +Status +RequestResponseCache::BuildInferenceResponse( + const CacheEntry& entry, InferenceResponse* const response) +{ + if (response == nullptr) { + return Status(Status::Code::INTERNAL, "invalid response ptr passed in"); + } + + // Lock on cache references + { + std::lock_guard lk(cache_mtx_); + + // Inference response outputs should be empty so we can append to them + if (response->Outputs().size() != 0) { + return Status( + Status::Code::INTERNAL, + "InferenceResponse already contains some outputs"); + } + + for (auto& cache_output : entry.outputs_) { + InferenceResponse::Output* response_output = nullptr; + RETURN_IF_ERROR(response->AddOutput( + cache_output.name_, cache_output.dtype_, cache_output.shape_, + &response_output)); + + if (response_output == nullptr) { + return Status( + Status::Code::INTERNAL, + "InferenceResponse::Output pointer as nullptr"); + } + + TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; + int64_t memory_type_id = 0; + + // Allocate buffer for inference response + void* buffer; + RETURN_IF_ERROR(response_output->AllocateDataBuffer( + &buffer, cache_output.buffer_size_, &memory_type, &memory_type_id)); + + // TODO: Handle other memory types + if (memory_type != TRITONSERVER_MEMORY_CPU && + memory_type != TRITONSERVER_MEMORY_CPU_PINNED) { + return Status( + Status::Code::INTERNAL, + "Only input buffers in CPU memory are allowed in cache currently"); + } + + if (buffer == nullptr) { + return Status( + Status::Code::INTERNAL, "failed to allocate buffer for output '" + + cache_output.name_ + "'"); + } + // Copy cached output buffer to allocated response output buffer + std::memcpy(buffer, cache_output.buffer_, cache_output.buffer_size_); + + // TODO: Add field to InferenceResponse to indicate this was from cache + // response.cached = true; + } + } + + return Status::Success; +} + +Status +RequestResponseCache::HashInputBuffers( + const InferenceRequest::Input* input, size_t* seed) +{ + // Iterate over each data buffer in input in case of non-contiguous memory + for (size_t idx = 0; idx < input->DataBufferCount(); ++idx) { + const void* src_buffer; + size_t src_byte_size; + TRITONSERVER_MemoryType src_memory_type; + int64_t src_memory_type_id; + + RETURN_IF_ERROR(input->DataBuffer( + idx, &src_buffer, &src_byte_size, &src_memory_type, + &src_memory_type_id)); + + // TODO: Handle other memory types + if (src_memory_type != TRITONSERVER_MEMORY_CPU && + src_memory_type != TRITONSERVER_MEMORY_CPU_PINNED) { + return Status( + Status::Code::INTERNAL, + "Only input buffers in CPU memory are allowed in cache currently"); + } + + // Add each byte of input buffer chunk to hash + const unsigned char* tmp = static_cast(src_buffer); + for (uint64_t byte = 0; byte < src_byte_size; byte++) { + boost::hash_combine(*seed, tmp[byte]); + } + } + + return Status::Success; +} + + +Status +RequestResponseCache::HashInputs(const 
InferenceRequest& request, size_t* seed) +{ + const auto& inputs = request.ImmutableInputs(); + // Convert inputs to ordered map for consistency in hashing + // inputs sorted by key (input) name + std::map ordered_inputs( + inputs.begin(), inputs.end()); + for (const auto& input : ordered_inputs) { + // Add input name to hash + boost::hash_combine(*seed, input.second->Name()); + // Fetch input buffer for hashing raw data + RETURN_IF_ERROR(HashInputBuffers(input.second, seed)); + } + + return Status::Success; +} + + +Status +RequestResponseCache::Hash(const InferenceRequest& request, uint64_t* key) +{ + std::size_t seed = 0; + // Add request model name to hash + boost::hash_combine(seed, request.ModelName()); + // Add request model version to hash + boost::hash_combine(seed, request.ActualModelVersion()); + RETURN_IF_ERROR(HashInputs(request, &seed)); + *key = static_cast(seed); + return Status::Success; +} + +Status +RequestResponseCache::HashAndSet(InferenceRequest* const request) +{ + uint64_t key = 0; + RETURN_IF_ERROR(Hash(*request, &key)); + request->SetCacheKey(key); + return Status::Success; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/response_cache.h b/3rdparty/core-r22.12/src/response_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..6c39655e67d392c2b1541fe15de3e983b1a70c85 --- /dev/null +++ b/3rdparty/core-r22.12/src/response_cache.h @@ -0,0 +1,198 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
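The cache key computed by `Hash()`/`HashInputs()` above is an order-sensitive fold of `boost::hash_combine` over the model name, model version, and each input's name and raw bytes, which is why the inputs are first copied into a `std::map` so iteration order is deterministic. The standalone sketch below reproduces that pattern on plain containers; the `FakeInput` struct and function name are illustrative stand-ins, not the `InferenceRequest` types used by the real code.

```cpp
// Standalone illustration of the response-cache hashing scheme: fold the model
// identity and every input (sorted by name) into one seed with hash_combine.
// FakeInput is hypothetical; the real code walks InferenceRequest data buffers.
#include <cstdint>
#include <map>
#include <string>
#include <vector>

#include <boost/functional/hash.hpp>

struct FakeInput {
  std::vector<unsigned char> bytes;  // stand-in for the input's data buffers
};

uint64_t
HashFakeRequest(
    const std::string& model_name, int64_t model_version,
    const std::map<std::string, FakeInput>& inputs)  // std::map => sorted names
{
  std::size_t seed = 0;
  boost::hash_combine(seed, model_name);
  boost::hash_combine(seed, model_version);
  for (const auto& pr : inputs) {
    boost::hash_combine(seed, pr.first);  // input name
    for (const unsigned char byte : pr.second.bytes) {
      boost::hash_combine(seed, byte);    // raw payload, byte by byte
    }
  }
  return static_cast<uint64_t>(seed);
}
```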
+ +#pragma once + +#include +#include +#include + +#include "infer_request.h" +#include "infer_response.h" +#include "model.h" +#include "status.h" + +#include +#include + +namespace triton { namespace core { + +// Assuming CPU memory only for now +struct Output { + // Output tensor data buffer + void* buffer_; + // Size of "buffer" above + uint64_t buffer_size_ = 0; + // Name of the output + std::string name_; + // Datatype of the output + inference::DataType dtype_; + // Shape of the output + std::vector shape_; +}; + +struct CacheEntry { + explicit CacheEntry() {} + // Point to key in LRU list for maintaining LRU order + std::list::iterator lru_iter_; + // each output buffer = managed_buffer.allocate(size, ...) + std::vector outputs_; +}; + +class RequestResponseCache { + public: + ~RequestResponseCache(); + // Create the request/response cache object + static Status Create( + uint64_t cache_size, std::unique_ptr* cache); + // Hash inference request for cache access and store it in "request" object. + // This will also be called internally in Lookup/Insert if the request hasn't + // already stored it's hash. It is up to the user to update the hash in the + // request if modifying any hashed fields of the request object after storing. + // Return Status object indicating success or failure. + Status HashAndSet(InferenceRequest* const request); + + // Lookup 'request' hash in cache and return the inference response in + // 'response' on cache hit or nullptr on cache miss + // Return Status object indicating success or failure. + Status Lookup( + InferenceResponse* const response, InferenceRequest* const request); + // Insert response into cache, evict entries to make space if necessary + // Return Status object indicating success or failure. + Status Insert( + const InferenceResponse& response, InferenceRequest* const request); + // Evict entry from cache based on policy + // Return Status object indicating success or failure. 
+ Status Evict(); + // Returns number of items in cache + size_t NumEntries() + { + std::lock_guard lk(cache_mtx_); + return cache_.size(); + } + // Returns number of items evicted in cache lifespan + size_t NumEvictions() + { + std::lock_guard lk(cache_mtx_); + return num_evictions_; + } + // Returns number of lookups in cache lifespan, should sum to hits + misses + size_t NumLookups() + { + std::lock_guard lk(cache_mtx_); + return num_lookups_; + } + // Returns number of cache hits in cache lifespan + size_t NumHits() + { + std::lock_guard lk(cache_mtx_); + return num_hits_; + } + // Returns number of cache hits in cache lifespan + size_t NumMisses() + { + std::lock_guard lk(cache_mtx_); + return num_misses_; + } + // Returns the total lookup latency (nanoseconds) of all lookups in cache + // lifespan + uint64_t TotalLookupLatencyNs() + { + std::lock_guard lk(cache_mtx_); + return total_lookup_latency_ns_; + } + + uint64_t TotalInsertionLatencyNs() + { + std::lock_guard lk(cache_mtx_); + return total_insertion_latency_ns_; + } + + // Returns total number of bytes allocated for cache + size_t TotalBytes() + { + std::lock_guard lk(buffer_mtx_); + return managed_buffer_.get_size(); + } + // Returns number of free bytes in cache + size_t FreeBytes() + { + std::lock_guard lk(buffer_mtx_); + return managed_buffer_.get_free_memory(); + } + // Returns number of bytes in use by cache + size_t AllocatedBytes() + { + std::lock_guard lk(buffer_mtx_); + return managed_buffer_.get_size() - managed_buffer_.get_free_memory(); + } + // Returns fraction of bytes allocated over total cache size between [0, 1] + double TotalUtilization() + { + std::lock_guard lk(buffer_mtx_); + return static_cast(AllocatedBytes()) / + static_cast(TotalBytes()); + } + + private: + explicit RequestResponseCache(const uint64_t cache_size); + // Update LRU ordering on lookup + void UpdateLRU(std::unordered_map::iterator&); + // Build CacheEntry from InferenceResponse + Status BuildCacheEntry( + const InferenceResponse& response, CacheEntry* const entry); + // Build InferenceResponse from CacheEntry + Status BuildInferenceResponse( + const CacheEntry& entry, InferenceResponse* const response); + // Helper function to hash data buffers used by "input" + Status HashInputBuffers(const InferenceRequest::Input* input, size_t* seed); + // Helper function to hash each input in "request" + Status HashInputs(const InferenceRequest& request, size_t* seed); + // Helper function to hash request and store it in "key" + Status Hash(const InferenceRequest& request, uint64_t* key); + + // Cache buffer + void* buffer_; + // Managed buffer + boost::interprocess::managed_external_buffer managed_buffer_; + // key -> CacheEntry containing values and list iterator for LRU management + std::unordered_map cache_; + // List of keys sorted from most to least recently used + std::list lru_; + // Cache metrics + size_t num_evictions_ = 0; + size_t num_lookups_ = 0; + size_t num_hits_ = 0; + size_t num_misses_ = 0; + uint64_t total_lookup_latency_ns_ = 0; + uint64_t total_insertion_latency_ns_ = 0; + // Mutex for buffer synchronization + std::recursive_mutex buffer_mtx_; + // Mutex for cache synchronization + std::recursive_mutex cache_mtx_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/scheduler.h b/3rdparty/core-r22.12/src/scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..7cc9142c8a7fc4a1cf86c352984493df7e1b3923 --- /dev/null +++ b/3rdparty/core-r22.12/src/scheduler.h @@ -0,0 +1,80 @@ +// 
Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include "infer_request.h" +#include "status.h" + +namespace triton { namespace core { + +// Scheduler interface. +class Scheduler { + public: + virtual ~Scheduler() {} + + // The prototype for the initialization function that will be called + // by the "standard" schedulers created based on a model's + // scheduling_choice settings. The init function is called once by + // the runner that will later execute requests for 'runner_idx'. A + // non-OK error status indicates an initialization error that + // prevents scheduler from using the runner. + using StandardInitFunc = std::function; + + // The prototype for the warmup function that will be called by the + // "standard" schedulers created based on a model's + // scheduling_choice settings. The warmup function is called once by + // the runner that will later execute requests for 'runner_idx'. A + // non-OK error status indicates an error that prevents scheduler + // from sending warmup requests to the runner. + using StandardWarmupFunc = std::function; + + // The prototype for the run function that will be called by the + // "standard" schedulers created based on a model's + // scheduling_choice settings. The run function must accept a + // 'runner_idx' indicating which runner should execute the + // 'requests'. Ownership of the 'requests' is transferred to the + // runner which is responsible for generating responses and + // releasing the requests. + using StandardRunFunc = std::function>&& requests)>; + + // Enqueue a request with the scheduler. If Status::Success is returned + // then the backend has taken ownership of the request object and so + // 'request' will be nullptr. If non-success is returned then the + // caller still retains ownership of 'request'. + virtual Status Enqueue(std::unique_ptr& request) = 0; + + // Return the number of in-flight inferences tracked by the scheduler. 
+ virtual size_t InflightInferenceCount() = 0; + + // Instruct the scheduler to stop processing future requests unless they are + // considered as in-flight. + virtual void Stop() = 0; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/scheduler_utils.cc b/3rdparty/core-r22.12/src/scheduler_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3a7e243744988a3e01db7ccffd822e19d214983 --- /dev/null +++ b/3rdparty/core-r22.12/src/scheduler_utils.cc @@ -0,0 +1,423 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "scheduler_utils.h" + +#include +#include "constants.h" +#include "triton/common/logging.h" + +namespace triton { namespace core { + +Status +RequiredEqualInputs::Initialize( + const std::unique_ptr& request, + const std::unordered_map& enforce_equal_shape_tensors, + const bool has_optional_input) +{ + has_optional_input_ = has_optional_input; + required_inputs_.clear(); + + for (const auto& pr : request->ImmutableInputs()) { + const InferenceRequest::Input* input = pr.second; + const auto itr = enforce_equal_shape_tensors.find(input->Name()); + if (itr != enforce_equal_shape_tensors.end()) { + required_inputs_.emplace( + std::piecewise_construct, std::forward_as_tuple(input->Name()), + std::forward_as_tuple(input, itr->second)); + } + // When the model has optional inputs, overload 'required_inputs_' + // to track the inputs involved in the batch + else if (has_optional_input) { + required_inputs_.emplace( + std::piecewise_construct, std::forward_as_tuple(input->Name()), + std::forward_as_tuple(nullptr, false)); + } + } + + init_ = true; + return Status::Success; +} + +bool +RequiredEqualInputs::HasEqualInputs( + const std::unique_ptr& request) +{ + // If current request has different number of inputs, then dynamic batching + // shouldn't be applied. 
+ if (has_optional_input_ && + (request->ImmutableInputs().size() != required_inputs_.size())) { + return false; + } + for (const auto& pr : request->ImmutableInputs()) { + const InferenceRequest::Input* input = pr.second; + const auto itr = required_inputs_.find(input->Name()); + if (itr != required_inputs_.end()) { + if (itr->second.first != nullptr) { + // Make sure shape of input tensors is equal. + if (!triton::common::CompareDims( + itr->second.first->Shape(), input->Shape())) { + return false; + } + + // If necessary compare the contents as well... + if (itr->second.second) { + const auto& d1 = itr->second.first->Data(); + const auto& d2 = input->Data(); + + // For now being conservative and assuming that content + // comparison is for shape tensors which are likely to always + // be in a single buffer. + if ((d1->BufferCount() != 1) || (d2->BufferCount() != 1)) { + return false; + } + + size_t d1_byte_size, d2_byte_size; + TRITONSERVER_MemoryType d1_memory_type, d2_memory_type; + int64_t d1_memory_id, d2_memory_id; + const char* d1_buffer = d1->BufferAt( + 0 /* idx */, &d1_byte_size, &d1_memory_type, &d1_memory_id); + const char* d2_buffer = d2->BufferAt( + 0 /* idx */, &d2_byte_size, &d2_memory_type, &d2_memory_id); + + // Tensor must be same size and in in CPU memory so that it + // can be easily compared. If not return false conservatively. + if ((d1_byte_size != d2_byte_size) || (d1_buffer == nullptr) || + (d2_buffer == nullptr) || + (d1_memory_type == TRITONSERVER_MEMORY_GPU) || + (d2_memory_type == TRITONSERVER_MEMORY_GPU)) { + return false; + } + + if (strncmp(d1_buffer, d2_buffer, d1_byte_size) != 0) { + return false; + } + } + } + } else if (has_optional_input_) { + // If the model has optional inputs, the current request must contains all + // inputs that in the first request (tracked in 'required_inputs_'). 
+ return false; + } + } + + return true; +} + +Status +PriorityQueue::PolicyQueue::Enqueue(std::unique_ptr& request) +{ + if ((max_queue_size_ != 0) && (Size() >= max_queue_size_)) { + return Status( + Status::Code::UNAVAILABLE, + request->LogRequest() + "Exceeds maximum queue size"); + } + + queue_.emplace_back(std::move(request)); + auto timeout_us = default_timeout_us_; + if (allow_timeout_override_) { + auto override_timeout_us = queue_.back()->TimeoutMicroseconds(); + if (override_timeout_us != 0 && override_timeout_us < timeout_us) { + timeout_us = override_timeout_us; + } + } + if (timeout_us != 0) { + timeout_timestamp_ns_.emplace_back( + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count() + + timeout_us * 1000); + } else { + timeout_timestamp_ns_.emplace_back(0); + } + + return Status::Success; +} + +Status +PriorityQueue::PolicyQueue::Dequeue(std::unique_ptr* request) +{ + if (!queue_.empty()) { + *request = std::move(queue_.front()); + queue_.pop_front(); + timeout_timestamp_ns_.pop_front(); + } else { + *request = std::move(delayed_queue_.front()); + delayed_queue_.pop_front(); + } + + return Status::Success; +} + +bool +PriorityQueue::PolicyQueue::ApplyPolicy( + size_t idx, size_t* rejected_count, size_t* rejected_batch_size) +{ + uint64_t now_nanoseconds = + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + if (idx < queue_.size()) { + size_t curr_idx = idx; + while (curr_idx < queue_.size()) { + if ((timeout_timestamp_ns_[curr_idx] != 0) && + (now_nanoseconds > timeout_timestamp_ns_[curr_idx])) { + if (timeout_action_ == inference::ModelQueuePolicy::DELAY) { + delayed_queue_.emplace_back(std::move(queue_[curr_idx])); + } else { + rejected_queue_.emplace_back(std::move(queue_[curr_idx])); + *rejected_count += 1; + *rejected_batch_size += + std::max(1U, rejected_queue_.back()->BatchSize()); + } + curr_idx++; + } else { + break; + } + } + + // Use range erasure on deque as all erasure functions are linear, + // this implies in the edge case where this function is always called on + // 'bad' index can be O(n^2). However, for data structures that are O(1) + // erasure, the traversal may not be as efficient due to cache miss + // (elements not stored contiguously). + queue_.erase(queue_.begin() + idx, queue_.begin() + curr_idx); + timeout_timestamp_ns_.erase( + timeout_timestamp_ns_.begin() + idx, + timeout_timestamp_ns_.begin() + curr_idx); + + // Current idx is pointing to an item with unexpired timeout + if (idx < queue_.size()) { + return true; + } + } + // At this point, idx is pointing to an item with expired timeout. + // If the item is in delayed queue, then return true. Otherwise, false + // meaning the queue has no item with this 'idx'. 
+ return ((idx - queue_.size()) < delayed_queue_.size()); +} + +void +PriorityQueue::PolicyQueue::ReleaseRejectedQueue( + std::deque>* requests) +{ + rejected_queue_.swap(*requests); +} + +const std::unique_ptr& +PriorityQueue::PolicyQueue::At(size_t idx) const +{ + if (idx < queue_.size()) { + return queue_[idx]; + } else { + return delayed_queue_[idx - queue_.size()]; + } +} + +uint64_t +PriorityQueue::PolicyQueue::TimeoutAt(size_t idx) +{ + if (idx < queue_.size()) { + return timeout_timestamp_ns_[idx]; + } else { + return 0; + } +} + +PriorityQueue::PriorityQueue() + : size_(0), front_priority_level_(0), last_priority_level_(0) +{ + inference::ModelQueuePolicy default_policy; + queues_.emplace(0, PolicyQueue(default_policy)); + front_priority_level_ = queues_.begin()->first; + ResetCursor(); +} + +PriorityQueue::PriorityQueue( + const inference::ModelQueuePolicy& default_queue_policy, + uint32_t priority_levels, const ModelQueuePolicyMap queue_policy_map) + : size_(0), last_priority_level_(priority_levels) +{ + if (priority_levels == 0) { + queues_.emplace(0, PolicyQueue(default_queue_policy)); + } else { + for (uint32_t level = 1; level <= priority_levels; level++) { + auto it = queue_policy_map.find(level); + if (it == queue_policy_map.end()) { + queues_.emplace(level, PolicyQueue(default_queue_policy)); + } else { + queues_.emplace(level, PolicyQueue(it->second)); + } + } + } + front_priority_level_ = queues_.begin()->first; + ResetCursor(); +} + +Status +PriorityQueue::Enqueue( + uint32_t priority_level, std::unique_ptr& request) +{ + auto status = queues_[priority_level].Enqueue(request); + if (status.IsOk()) { + size_++; + front_priority_level_ = std::min(front_priority_level_, priority_level); + // Invalidate the pending batch cursor if the enqueued item is placed + // within the pending batch. At the same priority level the request is + // guaranteed to be after pending batch if the batch hasn't reached + // delayed queue. + if ((priority_level < pending_cursor_.curr_it_->first) || + ((priority_level == pending_cursor_.curr_it_->first) && + (pending_cursor_.at_delayed_queue_))) { + pending_cursor_.valid_ = false; + } + } + + return status; +} + +Status +PriorityQueue::Dequeue(std::unique_ptr* request) +{ + pending_cursor_.valid_ = false; + while (true) { + if (!queues_[front_priority_level_].Empty()) { + RETURN_IF_ERROR(queues_[front_priority_level_].Dequeue(request)); + size_--; + return Status::Success; + } else if (front_priority_level_ != last_priority_level_) { + front_priority_level_++; + continue; + } + + // Control reach here if the queue for last priority level is also + // empty, then return error below. 
+ break; + } + + return Status( + Status::Code::UNAVAILABLE, + (*request)->LogRequest() + "dequeue on empty queue"); +} + +void +PriorityQueue::ReleaseRejectedRequests( + std::shared_ptr>>>* + requests) +{ + auto res = std::make_shared< + std::vector>>>( + queues_.size()); + size_t idx = 0; + for (auto& queue : queues_) { + queue.second.ReleaseRejectedQueue(&((*res)[idx])); + idx++; + } + + requests->swap(res); +} + +bool +PriorityQueue::IsCursorValid() +{ + if (pending_cursor_.valid_) { + return (uint64_t)std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count() < pending_cursor_.pending_batch_closest_timeout_ns_; + } + return false; +} + +PriorityQueue::Cursor::Cursor(PriorityQueues::iterator start_it) + : curr_it_(start_it), queue_idx_(0), at_delayed_queue_(false), + pending_batch_closest_timeout_ns_(0), + pending_batch_oldest_enqueue_time_ns_(0), pending_batch_count_(0), + valid_(true) +{ +} + +size_t +PriorityQueue::ApplyPolicyAtCursor() +{ + size_t rejected_batch_size = 0; + size_t rejected_count = 0; + while (pending_cursor_.curr_it_ != queues_.end()) { + if (!(pending_cursor_.curr_it_->second.ApplyPolicy( + pending_cursor_.queue_idx_, &rejected_count, + &rejected_batch_size))) { + if (size_ > pending_cursor_.pending_batch_count_ + rejected_count) { + pending_cursor_.curr_it_++; + pending_cursor_.queue_idx_ = 0; + continue; + } + } + // Control reach here if the cursor points to a request that is candidate + // for pending batch, or if all requests are in pending batch. + break; + } + size_ -= rejected_count; + return rejected_batch_size; +} + +void +PriorityQueue::AdvanceCursor() +{ + if (pending_cursor_.pending_batch_count_ >= size_) { + return; + } + + const auto& timeout_ns = + pending_cursor_.curr_it_->second.TimeoutAt(pending_cursor_.queue_idx_); + if (timeout_ns != 0) { + if (pending_cursor_.pending_batch_closest_timeout_ns_ != 0) { + pending_cursor_.pending_batch_closest_timeout_ns_ = std::min( + pending_cursor_.pending_batch_closest_timeout_ns_, timeout_ns); + } else { + pending_cursor_.pending_batch_closest_timeout_ns_ = timeout_ns; + } + } + + uint64_t curr_enqueue_time_ns = + pending_cursor_.curr_it_->second.At(pending_cursor_.queue_idx_) + ->BatcherStartNs(); + if (pending_cursor_.pending_batch_oldest_enqueue_time_ns_ != 0) { + pending_cursor_.pending_batch_oldest_enqueue_time_ns_ = std::min( + pending_cursor_.pending_batch_oldest_enqueue_time_ns_, + curr_enqueue_time_ns); + } else { + pending_cursor_.pending_batch_oldest_enqueue_time_ns_ = + curr_enqueue_time_ns; + } + ++pending_cursor_.queue_idx_; + ++pending_cursor_.pending_batch_count_; + // pending batch includes delayed request if (queue_idx_ - 1) points to + // delayed queue. + pending_cursor_.at_delayed_queue_ = + (pending_cursor_.queue_idx_ > + pending_cursor_.curr_it_->second.UnexpiredSize()); +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/scheduler_utils.h b/3rdparty/core-r22.12/src/scheduler_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..1790f369c24281588a30955807abdad351ca2092 --- /dev/null +++ b/3rdparty/core-r22.12/src/scheduler_utils.h @@ -0,0 +1,256 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include "scheduler.h" + +namespace triton { namespace core { + +struct RequiredEqualInputs { + public: + RequiredEqualInputs() : init_(false), has_optional_input_(false) {} + Status Initialize( + const std::unique_ptr& request, + const std::unordered_map& enforce_equal_shape_tensors, + const bool has_optional_input); + bool HasEqualInputs(const std::unique_ptr& request); + bool Initialized() { return init_; }; + + private: + bool init_; + bool has_optional_input_; + // A collection of inputs in the request, an nullptr for + // InferenceRequest::Input indicates that the inputs doesn't require + // equality check + std::unordered_map< + std::string, + std::pair> + required_inputs_; +}; + +// +// PriorityQueue +// +using ModelQueuePolicyMap = ::google::protobuf::Map< + ::google::protobuf::uint32, inference::ModelQueuePolicy>; + +class PriorityQueue { + public: + // Construct a queue with no priority level with default queue policy, + // which will behave the same as regular queue. + PriorityQueue(); + + // Construct a queue with 'priority_levels', the priority starts from 1. + // Different priority level may follow different queue policies given by + // 'queue_policy_map', otherwise, the 'default_queue_policy' will be used. + PriorityQueue( + const inference::ModelQueuePolicy& default_queue_policy, + uint32_t priority_levels, const ModelQueuePolicyMap queue_policy_map); + + // Enqueue a request with priority set to 'priority_level'. If + // Status::Success is returned then the queue has taken ownership of + // the request object and so 'request' will be nullptr. If + // non-success is returned then the caller still retains ownership + // of 'request'. + Status Enqueue( + uint32_t priority_level, std::unique_ptr& request); + + // Dequeue the request at the front of the queue. + Status Dequeue(std::unique_ptr* request); + + // Retrieve the requests that are rejected based on the queue policies. + void ReleaseRejectedRequests( + std::shared_ptr< + std::vector>>>* + requests); + + // Return the number of requests in the queue, rejected requests are + // not included. + size_t Size() { return size_; } + + // Is the queue is empty? Rejected requests are not included. 
+ bool Empty() { return Size() == 0; } + + // Reset the cursor such that it is representing an empty pending batch. + void ResetCursor() { pending_cursor_ = Cursor(queues_.begin()); } + + // Record the current cursor. The cursor can be restored to recorded state + // by invoking SetCursorToMark(). Note that Enqueue(), Dequeue(), and + // ResetCursor() will invalidate the marker, it is the function caller's + // responsibility to ensure the marker is valid before calling + // SetCursorToMark(). + void MarkCursor() { current_mark_ = pending_cursor_; } + + // Apply the queue policy and alter the underlying queue accordingly. After + // the function returns, the cursor may be at its end to indicate that + // there no request after the pending batch. + // Returns the total batch size of the newly rejected requests. + size_t ApplyPolicyAtCursor(); + + // Return the request at the cursor. + const std::unique_ptr& RequestAtCursor() + { + return pending_cursor_.curr_it_->second.At(pending_cursor_.queue_idx_); + } + + // Advance the cursor for pending batch. This function will not trigger the + // queue policy. No effect if the cursor already reach the end of the queue. + void AdvanceCursor(); + + // Whether the cursor reaches its end, + bool CursorEnd() { return pending_cursor_.pending_batch_count_ == size_; } + + // Restore the cursor state to the marker. + void SetCursorToMark() { pending_cursor_ = current_mark_; } + + // Whether the cursor is still valid. The cursor is valid only if the pending + // batch is unchanged. + bool IsCursorValid(); + + // Return the oldest queued time of requests in pending batch. + uint64_t OldestEnqueueTime() + { + return pending_cursor_.pending_batch_oldest_enqueue_time_ns_; + } + + // Return the closest timeout of requests in pending batch. + uint64_t ClosestTimeout() + { + return pending_cursor_.pending_batch_closest_timeout_ns_; + } + + // Return the number of requests in pending batch. + size_t PendingBatchCount() { return pending_cursor_.pending_batch_count_; } + + private: + class PolicyQueue { + public: + // Construct a policy queue with default policy, which will behave the same + // as regular queue. + PolicyQueue() + : timeout_action_(inference::ModelQueuePolicy::REJECT), + default_timeout_us_(0), allow_timeout_override_(false), + max_queue_size_(0) + { + } + + // Construct a policy queue with given 'policy'. + PolicyQueue(const inference::ModelQueuePolicy& policy) + : timeout_action_(policy.timeout_action()), + default_timeout_us_(policy.default_timeout_microseconds()), + allow_timeout_override_(policy.allow_timeout_override()), + max_queue_size_(policy.max_queue_size()) + { + } + + // Enqueue a request and set up its timeout accordingly. If + // Status::Success is returned then the queue has taken ownership + // of the request object and so 'request' will be nullptr. If + // non-success is returned then the caller still retains ownership + // of 'request'. + Status Enqueue(std::unique_ptr& request); + + // Dequeue the request at the front of the queue. + Status Dequeue(std::unique_ptr* request); + + // Apply the queue policy to the request at 'idx'. + // 'rejected_count' will be incremented by the number of the newly rejected + // requets after applying the policy. + // 'rejected_batch_size' will be incremented by the total batch size of the + // newly rejected requests after applying the policy. + // Return true if the 'idx' still points to a request after applying the + // policy, false otherwise. 
+ bool ApplyPolicy( + size_t idx, size_t* rejected_count, size_t* rejected_batch_size); + + // Return the rejected requests held by the queue. + void ReleaseRejectedQueue( + std::deque>* requests); + + // Return the request at 'idx'. + const std::unique_ptr& At(size_t idx) const; + + // Return the timeout timestamp of the request at 'idx', in ns. A value of 0 + // indicates that the request doesn't specify a timeout. + uint64_t TimeoutAt(size_t idx); + + // Return whether the queue is empty, rejected requests are not included. + bool Empty() { return Size() == 0; } + + // Return the number of requests in the queue, rejected requests are not + // included. + size_t Size() { return queue_.size() + delayed_queue_.size(); } + + // Return the number of unexpired requests in the queue + size_t UnexpiredSize() { return queue_.size(); } + + private: + // Variables that define the policy for the queue + const inference::ModelQueuePolicy::TimeoutAction timeout_action_; + const uint64_t default_timeout_us_; + const bool allow_timeout_override_; + const uint32_t max_queue_size_; + + std::deque timeout_timestamp_ns_; + std::deque> queue_; + std::deque> delayed_queue_; + std::deque> rejected_queue_; + }; + using PriorityQueues = std::map; + + // Cursor for tracking pending batch, the cursor points to the item after + // the pending batch. + struct Cursor { + Cursor() = default; + Cursor(PriorityQueues::iterator start_it); + + Cursor(const Cursor& rhs) = default; + Cursor& operator=(const Cursor& rhs) = default; + + PriorityQueues::iterator curr_it_; + size_t queue_idx_; + bool at_delayed_queue_; + uint64_t pending_batch_closest_timeout_ns_; + uint64_t pending_batch_oldest_enqueue_time_ns_; + size_t pending_batch_count_; + bool valid_; + }; + + PriorityQueues queues_; + size_t size_; + + // Keep track of the priority level that the first request in the queue + // is at to avoid traversing 'queues_' + uint32_t front_priority_level_; + uint32_t last_priority_level_; + + Cursor pending_cursor_; + Cursor current_mark_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/sequence_batch_scheduler.cc b/3rdparty/core-r22.12/src/sequence_batch_scheduler.cc new file mode 100644 index 0000000000000000000000000000000000000000..67f9ded8b8c1f2c3b54f38e21b12ad4898227a84 --- /dev/null +++ b/3rdparty/core-r22.12/src/sequence_batch_scheduler.cc @@ -0,0 +1,1687 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "sequence_batch_scheduler.h" + +#ifndef _WIN32 +#include +#include +#include +#endif +#include +#include "constants.h" +#include "dynamic_batch_scheduler.h" +#include "model_config_utils.h" +#include "server.h" +#include "triton/common/logging.h" + +namespace triton { namespace core { + +Status +SequenceBatchScheduler::Create( + TritonModel* model, + const std::unordered_map& enforce_equal_shape_tensors, + std::unique_ptr* scheduler) +{ + std::unique_ptr sched(new SequenceBatchScheduler()); + + // For debugging and testing, + const char* dstr = getenv("TRITONSERVER_BACKLOG_DELAY_SCHEDULER"); + sched->backlog_delay_cnt_ = 0; + if (dstr != nullptr) { + sched->backlog_delay_cnt_ = atoi(dstr); + LOG_INFO << "Delaying scheduler until " << sched->backlog_delay_cnt_ + << " backlog queued requests..."; + } + + auto instance_count = model->Instances().size(); + sched->queue_request_cnts_.resize(instance_count, 0); + + auto& config = model->Config(); + + // Max sequence idle... + sched->max_sequence_idle_microseconds_ = + config.sequence_batching().max_sequence_idle_microseconds(); + + sched->max_batch_size_ = config.max_batch_size(); + + // Implicit States + auto& states = config.sequence_batching().state(); + + for (const inference::ModelSequenceBatching_State& state : states) { + sched->state_output_config_map_.insert({state.output_name(), state}); + + if (state.initial_state_size() > 1) { + return Status( + Status::Code::INVALID_ARG, + std::string( + std::string("initial_state field for state input '") + + state.input_name() + + "' must contain exactly one or zero element. Found '" + + std::to_string(state.initial_state_size()) + "' elements.")); + } + + // If the model configuration has initial_state field. + if (state.initial_state_size() == 1) { + auto& initial_state = state.initial_state(0); + RETURN_IF_ERROR( + sched->GenerateInitialStateData(initial_state, state, model)); + } + } + + // Get the number of candidate sequence slots to allow for each + // runner. This is at least 1 even if the model doesn't support + // batching. + const size_t model_batch_size = std::max(1, config.max_batch_size()); + size_t seq_slot_cnt = model_batch_size; + if (config.sequence_batching().has_oldest()) { + seq_slot_cnt = + config.sequence_batching().oldest().max_candidate_sequences(); + } + + // Based on the model configuration create input tensors for control + // signals indicating sequence start, sequence continue, and + // sequence not ready. + std::shared_ptr start; + std::shared_ptr end; + std::shared_ptr startend; + std::shared_ptr cont; + std::shared_ptr notready; + RETURN_IF_ERROR(sched->CreateBooleanControlTensors( + config, &start, &end, &startend, &cont, ¬ready)); + + bool has_optional_input = false; + for (const auto& input : config.input()) { + if (input.optional()) { + has_optional_input = true; + break; + } + } + + // Create one SequenceBatch object for each requested runner. 
The + // SequenceBatch object has a thread that manages the batch of + // requests. + const auto& instances = model->Instances(); + uint32_t index = 0; + for (const auto& instance : instances) { + bool init_state; + std::unique_ptr sb; + + // Create the SequenceBatch derivative that handles the requested + // scheduling strategy. + if (config.sequence_batching().has_oldest()) { + sb.reset(new OldestSequenceBatch( + sched.get(), index, seq_slot_cnt, instance.get(), + enforce_equal_shape_tensors, has_optional_input, start, end, startend, + cont, notready, &init_state)); + } else { + sb.reset(new DirectSequenceBatch( + sched.get(), index, seq_slot_cnt, instance.get(), + enforce_equal_shape_tensors, has_optional_input, start, end, startend, + cont, notready, &init_state)); + } + + if (init_state) { + sched->batchers_.push_back(std::move(sb)); + // All sequence slots in the batcher are initially ready for a + // new sequence. + for (size_t b = 0; b < seq_slot_cnt; ++b) { + sched->ready_batcher_seq_slots_.push( + SequenceBatchScheduler::BatcherSequenceSlot(index, b)); + } + } + ++index; + } + if (sched->batchers_.empty()) { + return Status( + Status::Code::INTERNAL, + "Initialization failed for all sequence-batch scheduler threads"); + } + + // Create a reaper thread that watches for idle sequences. Run the + // reaper a lower priority. + SequenceBatchScheduler* raw = sched.release(); + + raw->reaper_thread_exit_ = false; + raw->reaper_thread_.reset( + new std::thread([raw]() { raw->ReaperThread(10 /* nice */); })); + + scheduler->reset(raw); + + return Status::Success; +} + +Status +SequenceBatchScheduler::GenerateInitialStateData( + const inference::ModelSequenceBatching_InitialState& initial_state, + const inference::ModelSequenceBatching_State& state, TritonModel* model) +{ + if (initial_state.data_type() != state.data_type()) { + return Status( + Status::Code::INVALID_ARG, + std::string("The data type used for 'initial_state' field of state '") + + state.input_name() + "' does not match the state data type."); + } + + if (initial_state.name().size() == 0) { + return Status( + Status::Code::INVALID_ARG, + std::string("Field 'name' must be set when using initial_state for " + "state input '") + + state.input_name() + "'."); + } + + auto initial_state_itr = initial_state_.find(state.input_name()); + if (initial_state_itr != initial_state_.end()) { + return Status( + Status::Code::INVALID_ARG, std::string("State input name '") + + state.input_name() + + "' specified more than once."); + } + + if (initial_state.dims().size() != state.dims().size()) { + return Status( + Status::Code::INVALID_ARG, + std::string( + "Number of dimensions in 'initial_state' doesn't match the size of" + " 'state' dimensions for state input '") + + state.input_name() + "'. " + + std::to_string(initial_state.dims().size()) + + " != " + std::to_string(state.dims().size())); + } + + // Check the dimensions to make sure it doesn't have variable-sized dims and + // matches the state description. 
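+  // For illustration, a few hypothetical dim pairings and how the check below
+  // treats them:
+  //
+  //   state dims    initial_state dims   result
+  //   [-1, 128]     [ 4, 128]            accepted: -1 in 'state' matches any fixed dim
+  //   [ 4, 128]     [ 4, 128]            accepted: exact match
+  //   [ 4, 128]     [ 8, 128]            rejected: 4 != 8
+  //   [ 4, 128]     [-1, 128]            rejected: 'initial_state' may not be variable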
+ auto initial_state_dim = initial_state.dims().begin(); + auto state_dim = state.dims().begin(); + for (; initial_state_dim != initial_state.dims().end(); + initial_state_dim++, state_dim++) { + if (*initial_state_dim == -1) { + return Status( + Status::Code::INVALID_ARG, + std::string("'initial_state' field for state input name '") + + state.input_name() + "' contains variable dimensions."); + } else { + if (*state_dim != -1 && *initial_state_dim != *state_dim) { + return Status( + Status::Code::INVALID_ARG, + std::string("'initial_state' dim for input name '") + + state.input_name() + + "' doesn't match 'state' dim description. " + + std::to_string(*initial_state_dim) + + " != " + std::to_string(*state_dim)); + } + } + } + + const auto& initial_state_pair = initial_state_.emplace( + std::piecewise_construct, std::forward_as_tuple(state.input_name()), + std::forward_as_tuple(initial_state.name())); + auto& initial_state_data = initial_state_pair.first->second; + + // Calculate total memory byte size + auto element_count = triton::common::GetElementCount(initial_state.dims()); + size_t dtype_byte_size = + triton::common::GetDataTypeByteSize(initial_state.data_type()); + size_t total_byte_size = element_count * dtype_byte_size; + + // Custom handling for TYPE_BYTES + if (dtype_byte_size == 0) { + total_byte_size = sizeof(int32_t) * element_count; + } + + switch (initial_state.state_data_case()) { + case inference::ModelSequenceBatching_InitialState::StateDataCase:: + kZeroData: { + initial_state_data.data_ = std::make_shared( + total_byte_size, TRITONSERVER_MEMORY_CPU /* memory_type */, + 0 /* memory_type_id */); + + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + char* data_ptr = initial_state_data.data_->MutableBuffer( + &memory_type, &memory_type_id); + memset(data_ptr, 0, total_byte_size); + break; + } + case inference::ModelSequenceBatching_InitialState::StateDataCase:: + kDataFile: { + std::string file_input; + RETURN_IF_ERROR(ReadTextFile( + JoinPath({model->LocalizedModelPath(), kInitialStateFolder, + (initial_state.data_file())}), + &file_input)); + if (initial_state.data_type() == inference::DataType::TYPE_STRING) { + total_byte_size = file_input.size(); + } else if (total_byte_size > file_input.size()) { + return Status( + Status::Code::INVALID_ARG, + "initial_state setting expects " + std::to_string(total_byte_size) + + " bytes, but the data " + "provided from " + + initial_state.data_file() + "only has " + + std::to_string(file_input.size()) + " bytes."); + } + + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + + initial_state_data.data_ = std::make_shared( + total_byte_size, TRITONSERVER_MEMORY_CPU /* memory_type */, + 0 /* memory_type_id */); + char* data_ptr = initial_state_data.data_->MutableBuffer( + &memory_type, &memory_type_id); + memcpy(data_ptr, file_input.data(), total_byte_size); + + break; + } + default: + return Status( + Status::Code::INVALID_ARG, + std::string("initial_state setting expects state'") + + state.input_name() + "' to have state_data set"); + } + + return Status::Success; +} + +SequenceBatchScheduler::~SequenceBatchScheduler() +{ + // Signal the reaper thread to exit... 
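+  // A minimal sketch of the shutdown handshake used below (the common
+  // flag + condition_variable pattern):
+  //
+  //   { std::unique_lock<std::mutex> lk(mu_); reaper_thread_exit_ = true; }
+  //   reaper_cv_.notify_one();
+  //   if (reaper_thread_->joinable()) reaper_thread_->join();
+  //
+  // The flag is written under mu_ so it is ordered with the reaper's own
+  // locking; because ReaperThread() waits with a bounded wait_for(), a
+  // notification that races with the wait is at worst absorbed by the next
+  // timeout, so join() always completes.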
+ { + std::unique_lock lock(mu_); + reaper_thread_exit_ = true; + } + + reaper_cv_.notify_one(); + if ((reaper_thread_ != nullptr) && reaper_thread_->joinable()) { + reaper_thread_->join(); + } + + // Release 'batchers_' before other member variables because 'batchers_' + // can access 'this' and we need to make sure the member variables live + // longer than 'batchers_' + batchers_.clear(); +} + + +namespace { + +Status +GetBooleanOverrideInputs( + const std::string& tensor_name, const bool support_batching, + const inference::DataType tensor_datatype, const float fp32_false_value, + const float fp32_true_value, const int32_t int32_false_value, + const int32_t int32_true_value, const bool bool_false_value, + const bool bool_true_value, + std::shared_ptr* true_override, + std::shared_ptr* false_override) +{ + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + + const std::vector tensor_shape{1}; + std::vector tensor_shape_with_batch_dim{1}; + if (support_batching) { + tensor_shape_with_batch_dim.push_back(1); + } + const size_t size_p = triton::common::GetDataTypeByteSize(tensor_datatype); + + auto true_p = + std::make_shared(size_p, TRITONSERVER_MEMORY_CPU, 0); + char* true_p_ptr = true_p->MutableBuffer(&memory_type, &memory_type_id); + if ((true_p_ptr == nullptr) || + ((memory_type != TRITONSERVER_MEMORY_CPU) && + (memory_type != TRITONSERVER_MEMORY_CPU_PINNED)) || + (memory_type_id != 0)) { + return Status( + Status::Code::INTERNAL, + "failed to allocate sequence control signal in CPU memory"); + } + + auto false_p = + std::make_shared(size_p, TRITONSERVER_MEMORY_CPU, 0); + char* false_p_ptr = false_p->MutableBuffer(&memory_type, &memory_type_id); + if ((false_p_ptr == nullptr) || + ((memory_type != TRITONSERVER_MEMORY_CPU) && + (memory_type != TRITONSERVER_MEMORY_CPU_PINNED)) || + (memory_type_id != 0)) { + return Status( + Status::Code::INTERNAL, + "failed to allocate sequence control signal in CPU memory"); + } + + if (tensor_datatype == inference::DataType::TYPE_INT32) { + *(reinterpret_cast(true_p_ptr)) = int32_true_value; + *(reinterpret_cast(false_p_ptr)) = int32_false_value; + } else if (tensor_datatype == inference::DataType::TYPE_FP32) { + *(reinterpret_cast(true_p_ptr)) = fp32_true_value; + *(reinterpret_cast(false_p_ptr)) = fp32_false_value; + } else { + *(reinterpret_cast(true_p_ptr)) = bool_true_value; + *(reinterpret_cast(false_p_ptr)) = bool_false_value; + } + + auto ltrue_override = std::make_shared( + tensor_name, tensor_datatype, tensor_shape); + *ltrue_override->MutableShape() = ltrue_override->OriginalShape(); + *ltrue_override->MutableShapeWithBatchDim() = tensor_shape_with_batch_dim; + RETURN_IF_ERROR(ltrue_override->SetData(true_p)); + + auto lfalse_override = std::make_shared( + tensor_name, tensor_datatype, tensor_shape); + *lfalse_override->MutableShape() = lfalse_override->OriginalShape(); + *lfalse_override->MutableShapeWithBatchDim() = tensor_shape_with_batch_dim; + RETURN_IF_ERROR(lfalse_override->SetData(false_p)); + + *true_override = std::move(ltrue_override); + *false_override = std::move(lfalse_override); + + return Status::Success; +} + +} // namespace + +Status +SequenceBatchScheduler::CreateBooleanControlTensors( + const inference::ModelConfig& config, + std::shared_ptr* start_input_overrides, + std::shared_ptr* end_input_overrides, + std::shared_ptr* startend_input_overrides, + std::shared_ptr* continue_input_overrides, + std::shared_ptr* notready_input_overrides) +{ + // Currently only batch-size 1 requests are supported so only 
need + // to provide control vectors of that size. + *start_input_overrides = std::make_shared(); + *end_input_overrides = std::make_shared(); + *startend_input_overrides = std::make_shared(); + *continue_input_overrides = std::make_shared(); + *notready_input_overrides = std::make_shared(); + + std::string tensor_name; + inference::DataType tensor_datatype; + int32_t int32_false_value, int32_true_value; + float fp32_false_value, fp32_true_value; + bool bool_false_value, bool_true_value; + + // START, optional + { + RETURN_IF_ERROR(GetBooleanSequenceControlProperties( + config.sequence_batching(), config.name(), + inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_START, + false /* required */, &tensor_name, &tensor_datatype, &fp32_false_value, + &fp32_true_value, &int32_false_value, &int32_true_value, + &bool_false_value, &bool_true_value)); + if (!tensor_name.empty()) { + std::shared_ptr true_override; + std::shared_ptr false_override; + + RETURN_IF_ERROR(GetBooleanOverrideInputs( + tensor_name, config.max_batch_size() != 0, tensor_datatype, + fp32_false_value, fp32_true_value, int32_false_value, + int32_true_value, bool_false_value, bool_true_value, &true_override, + &false_override)); + + (*start_input_overrides)->emplace_back(true_override); + (*end_input_overrides)->emplace_back(false_override); + (*startend_input_overrides)->emplace_back(true_override); + (*continue_input_overrides)->emplace_back(false_override); + (*notready_input_overrides)->emplace_back(false_override); + } + } + + // END, optional + { + RETURN_IF_ERROR(GetBooleanSequenceControlProperties( + config.sequence_batching(), config.name(), + inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_END, + false /* required */, &tensor_name, &tensor_datatype, &fp32_false_value, + &fp32_true_value, &int32_false_value, &int32_true_value, + &bool_false_value, &bool_true_value)); + if (!tensor_name.empty()) { + std::shared_ptr true_override; + std::shared_ptr false_override; + + RETURN_IF_ERROR(GetBooleanOverrideInputs( + tensor_name, config.max_batch_size() != 0, tensor_datatype, + fp32_false_value, fp32_true_value, int32_false_value, + int32_true_value, bool_false_value, bool_true_value, &true_override, + &false_override)); + + (*start_input_overrides)->emplace_back(false_override); + (*end_input_overrides)->emplace_back(true_override); + (*startend_input_overrides)->emplace_back(true_override); + (*continue_input_overrides)->emplace_back(false_override); + (*notready_input_overrides)->emplace_back(false_override); + } + } + + // READY, optional + { + RETURN_IF_ERROR(GetBooleanSequenceControlProperties( + config.sequence_batching(), config.name(), + inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_READY, + false /* required */, &tensor_name, &tensor_datatype, &fp32_false_value, + &fp32_true_value, &int32_false_value, &int32_true_value, + &bool_false_value, &bool_true_value)); + if (!tensor_name.empty()) { + std::shared_ptr true_override; + std::shared_ptr false_override; + + RETURN_IF_ERROR(GetBooleanOverrideInputs( + tensor_name, config.max_batch_size() != 0, tensor_datatype, + fp32_false_value, fp32_true_value, int32_false_value, + int32_true_value, bool_false_value, bool_true_value, &true_override, + &false_override)); + + (*start_input_overrides)->emplace_back(true_override); + (*end_input_overrides)->emplace_back(true_override); + (*startend_input_overrides)->emplace_back(true_override); + (*continue_input_overrides)->emplace_back(true_override); + 
(*notready_input_overrides)->emplace_back(false_override); + } + } + + return Status::Success; +} + +Status +SequenceBatchScheduler::Enqueue(std::unique_ptr& irequest) +{ + // Queue timer starts at the beginning of the queueing and + // scheduling process + irequest->CaptureQueueStartNs(); + INFER_TRACE_ACTIVITY( + irequest->Trace(), TRITONSERVER_TRACE_QUEUE_START, + irequest->QueueStartNs()); + + // Record time at the beginning of the batcher queueing + irequest->CaptureBatcherStartNs(); + + // For now the request must have batch-size 1 since the sequence + // batcher does not yet support requests that are statically + // batched. + if (irequest->BatchSize() > 1) { + return Status( + Status::Code::INVALID_ARG, + "inference request to model '" + irequest->ModelName() + + "' must specify batch-size 1 due to requirements of sequence " + "batcher"); + } + + // A request must have a correlation ID to be processed correctly by + // this scheduler. A value of 0 (zero) or "" (empty) indicates that the + // request doesn't have a correlation ID. + const InferenceRequest::SequenceId& correlation_id = + irequest->CorrelationId(); + if (!correlation_id.InSequence()) { + return Status( + Status::Code::INVALID_ARG, + "inference request to model '" + irequest->ModelName() + + "' must specify a non-zero or non-empty correlation ID"); + } + + BatcherSequenceSlot* target = nullptr; + + const bool seq_start = + ((irequest->Flags() & TRITONSERVER_REQUEST_FLAG_SEQUENCE_START) != 0); + const bool seq_end = + ((irequest->Flags() & TRITONSERVER_REQUEST_FLAG_SEQUENCE_END) != 0); + + // Check if the request is one of the in-flight sequence (not starting new + // sequence), we consider sequences in backlog as also in-flight. + if (stop_ && seq_start) { + return Status( + Status::Code::UNAVAILABLE, + "Server is stopping, scheduler for model has stopped accepting new " + "inference requests"); + } + + std::unique_lock lock(mu_); + + auto sb_itr = sequence_to_batcherseqslot_map_.find(correlation_id); + auto bl_itr = sequence_to_backlog_map_.find(correlation_id); + + // If this request is not starting a new sequence its correlation ID + // should already be known with a target in either a sequence slot + // or in the backlog. If it doesn't then the sequence wasn't started + // correctly or there has been a correlation ID conflict. In either + // case fail this request. + if (!seq_start && (sb_itr == sequence_to_batcherseqslot_map_.end()) && + (bl_itr == sequence_to_backlog_map_.end())) { + std::string correlation_id_str{""}; + if (correlation_id.Type() == + InferenceRequest::SequenceId::DataType::STRING) { + correlation_id_str = correlation_id.StringValue(); + } else if ( + correlation_id.Type() == + InferenceRequest::SequenceId::DataType::UINT64) { + correlation_id_str = std::to_string(correlation_id.UnsignedIntValue()); + } + return Status( + Status::Code::INVALID_ARG, + "inference request for sequence " + correlation_id_str + " to model '" + + irequest->ModelName() + + "' must specify the START flag on the first request of the " + "sequence"); + } + + // Record the timestamp of this request for the correlation ID. The + // reaper thread will check to make sure that + // max_sequence_idle_microseconds value is not exceed for any + // sequence, and if it is it will release the sequence slot (if any) + // allocated to that sequence. 
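+  // For illustration (hypothetical value): the map records, per correlation
+  // ID, the steady_clock time of the most recent request in microseconds,
+  //
+  //   correlation_id_timestamps_[corr_id] = 1700000123456;  // last activity
+  //
+  // and the reaper compares (now_us - last activity) against
+  // max_sequence_idle_microseconds_ to decide when to force-end the sequence.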
+ { + uint64_t now_us = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + correlation_id_timestamps_[correlation_id] = now_us; + } + + // If this request starts a new sequence but the correlation ID + // already has an in-progress sequence then that previous sequence + // did not end correctly, or there is a correlation ID conflict. In + // this case we continue the new sequence (in either backlog or + // sequence slot). It is ok for a backlog/slot to have multiple + // starts... as long as it has a single end. The previous sequence + // that was not correctly ended will have its existing requests + // handled and then the new sequence will start. + if (seq_start && ((sb_itr != sequence_to_batcherseqslot_map_.end()) || + (bl_itr != sequence_to_backlog_map_.end()))) { + LOG_WARNING + << "sequence " << correlation_id << " for model '" + << irequest->ModelName() + << "' has a conflict. The previous sequence did not end before this " + "sequence start. Previous sequence will be terminated early."; + } + + // This request already has an assigned slot... + if (sb_itr != sequence_to_batcherseqslot_map_.end()) { + target = &sb_itr->second; + } + // This request already has a queue in the backlog... + else if (bl_itr != sequence_to_backlog_map_.end()) { + LOG_VERBOSE(1) << "Enqueuing CORRID " << correlation_id + << " into existing backlog: " << irequest->ModelName(); + + bl_itr->second->emplace_back(std::move(irequest)); + + // If the sequence is ending then forget correlation ID + // connection to this backlog queue. If another sequence starts + // with the same correlation ID it will be collected in another + // backlog queue. + if (seq_end) { + sequence_to_backlog_map_.erase(bl_itr); + } + return Status::Success; + } + // This request does not have an assigned backlog or sequence + // slot. By the above checks it must be starting. If there is a free + // sequence slot available then assign this sequence to that slot... + else if (!ready_batcher_seq_slots_.empty()) { + target = &sequence_to_batcherseqslot_map_[correlation_id]; + *target = ready_batcher_seq_slots_.top(); + ready_batcher_seq_slots_.pop(); + } + // Last option is to assign this request to the backlog... + else { + LOG_VERBOSE(1) << "Enqueuing CORRID " << correlation_id + << " into new backlog: " << irequest->ModelName(); + + auto backlog = + std::make_shared>>(); + backlog_queues_.push_back(backlog); + backlog->emplace_back(std::move(irequest)); + if (!seq_end) { + sequence_to_backlog_map_[correlation_id] = std::move(backlog); + } + return Status::Success; + } + + // Need to grab the target contents before the erase below since + // that can free it. + const size_t batcher_idx = target->batcher_idx_; + const uint32_t seq_slot = target->seq_slot_; + + // At this point the request has been assigned to a sequence + // slot. If the sequence is ending then stop tracking the + // correlation. + if (seq_end) { + sequence_to_batcherseqslot_map_.erase(correlation_id); + } + + // Enqueue request into batcher and sequence slot. Don't hold the + // lock while enqueuing in a specific batcher. 
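+  // In summary, the routing above resolves to one of four outcomes:
+  //
+  //   existing sequence slot     -> 'target' points at that batcher/slot
+  //   existing backlog queue     -> request appended there; already returned
+  //   free slot available        -> slot popped from ready_batcher_seq_slots_
+  //   otherwise                  -> new backlog queue created; already returned
+  //
+  // Only the first and third outcomes reach this point, so 'target' is always
+  // valid here.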
+ lock.unlock(); + + LOG_VERBOSE(1) << "Enqueuing CORRID " << correlation_id << " into batcher " + << batcher_idx << ", sequence slot " << seq_slot << ": " + << irequest->ModelName(); + + batchers_[batcher_idx]->Enqueue(seq_slot, correlation_id, irequest); + + return Status::Success; +} + +InferenceRequest::SequenceId +SequenceBatchScheduler::ReleaseSequenceSlot( + const BatcherSequenceSlot& batcher_seq_slot, + std::deque>* requests) +{ + std::unique_lock lock(mu_); + + // If there is a backlogged sequence and it is requested, return it + // so that it can use the newly available sequence slot. + if (!backlog_queues_.empty()) { + auto& backlog = backlog_queues_.front(); + *requests = std::move(*backlog); + backlog_queues_.pop_front(); + if (!requests->empty()) { // should never be empty... + const auto& irequest = requests->back(); + const InferenceRequest::SequenceId& correlation_id = + irequest->CorrelationId(); + + // If the last queue entry is not an END request then the entire + // sequence is not contained in the backlog. In that case must + // update backlog and batcherseqslot maps so that future + // requests get directed to the batcher sequence-slot instead of + // the backlog. + const bool seq_end = + ((irequest->Flags() & TRITONSERVER_REQUEST_FLAG_SEQUENCE_END) != 0); + if (!seq_end) { + // Since the correlation ID is being actively collected in the + // backlog, there should not be any in-flight sequences with + // that same correlation ID that have an assigned slot. + if (sequence_to_batcherseqslot_map_.find(correlation_id) != + sequence_to_batcherseqslot_map_.end()) { + LOG_ERROR << irequest->LogRequest() << "internal: backlog sequence " + << correlation_id + << " conflicts with in-flight sequence for model '" + << irequest->ModelName() << "'"; + } + + sequence_to_backlog_map_.erase(correlation_id); + sequence_to_batcherseqslot_map_[correlation_id] = batcher_seq_slot; + } + + LOG_VERBOSE(1) << irequest->LogRequest() << "CORRID " << correlation_id + << " reusing batcher " << batcher_seq_slot.batcher_idx_ + << ", slot " << batcher_seq_slot.seq_slot_ << ": " + << irequest->ModelName(); + return correlation_id; + } + } + + // There is no backlogged sequence so just release the batch slot + LOG_VERBOSE(1) << "Freeing slot in batcher " << batcher_seq_slot.batcher_idx_ + << ", slot " << batcher_seq_slot.seq_slot_; + + ready_batcher_seq_slots_.push(batcher_seq_slot); + return InferenceRequest::SequenceId(); +} + +bool +SequenceBatchScheduler::DelayScheduler( + const uint32_t batcher_idx, const size_t cnt, const size_t total) +{ + std::unique_lock lock(mu_); + queue_request_cnts_[batcher_idx] = cnt; + + size_t seen = 0; + for (auto c : queue_request_cnts_) { + seen += c; + } + + if (seen < total) { + return true; + } + + if (backlog_delay_cnt_ > 0) { + size_t backlog_seen = 0; + for (const auto& q : backlog_queues_) { + backlog_seen += q->size(); + } + + if (backlog_seen < backlog_delay_cnt_) { + return true; + } + } + + return false; +} + +void +SequenceBatchScheduler::ReaperThread(const int nice) +{ +#ifndef _WIN32 + if (setpriority(PRIO_PROCESS, syscall(SYS_gettid), nice) == 0) { + LOG_VERBOSE(1) << "Starting sequence-batch reaper thread at nice " << nice + << "..."; + } else { + LOG_VERBOSE(1) << "Starting sequence-batch reaper thread at default nice " + "(requested nice " + << nice << " failed)..."; + } +#else + LOG_VERBOSE(1) << "Starting sequence-batch reaper thread at default nice..."; +#endif + + const uint64_t backlog_idle_wait_microseconds = 50 * 1000; + + while 
(!reaper_thread_exit_) { + uint64_t wait_microseconds = max_sequence_idle_microseconds_; + BatcherSequenceSlotMap force_end_sequences; + + { + std::unique_lock lock(mu_); + + uint64_t now_us = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + for (auto cid_itr = correlation_id_timestamps_.cbegin(); + cid_itr != correlation_id_timestamps_.cend();) { + int64_t remaining_microseconds = + (int64_t)max_sequence_idle_microseconds_ - + (now_us - cid_itr->second); + if (remaining_microseconds > 0) { + wait_microseconds = + std::min(wait_microseconds, (uint64_t)remaining_microseconds + 1); + ++cid_itr; + continue; + } + + const InferenceRequest::SequenceId& idle_correlation_id = + cid_itr->first; + LOG_VERBOSE(1) << "Reaper: CORRID " << idle_correlation_id + << ": max sequence idle exceeded"; + + auto idle_sb_itr = + sequence_to_batcherseqslot_map_.find(idle_correlation_id); + + // If the idle correlation ID has an assigned sequence slot, + // then release that assignment so it becomes available for + // another sequence. Release is done by enqueuing and must be + // done outside the lock, so just collect needed info here. + if (idle_sb_itr != sequence_to_batcherseqslot_map_.end()) { + force_end_sequences[idle_correlation_id] = idle_sb_itr->second; + + sequence_to_batcherseqslot_map_.erase(idle_correlation_id); + cid_itr = correlation_id_timestamps_.erase(cid_itr); + } else { + // If the idle correlation ID is in the backlog, then just + // need to increase the timeout so that we revisit it again in + // the future to check if it is assigned to a sequence slot. + auto idle_bl_itr = sequence_to_backlog_map_.find(idle_correlation_id); + if (idle_bl_itr != sequence_to_backlog_map_.end()) { + LOG_VERBOSE(1) << "Reaper: found idle CORRID " + << idle_correlation_id; + wait_microseconds = + std::min(wait_microseconds, backlog_idle_wait_microseconds); + ++cid_itr; + } else { + LOG_VERBOSE(1) << "Reaper: ignoring stale idle CORRID " + << idle_correlation_id; + cid_itr = correlation_id_timestamps_.erase(cid_itr); + } + } + } + } + + // Enqueue force-ends outside of the lock. + for (const auto& pr : force_end_sequences) { + const InferenceRequest::SequenceId& idle_correlation_id = pr.first; + const size_t batcher_idx = pr.second.batcher_idx_; + const uint32_t seq_slot = pr.second.seq_slot_; + + LOG_VERBOSE(1) << "Reaper: force-ending CORRID " << idle_correlation_id + << " in batcher " << batcher_idx << ", slot " << seq_slot; + + // A slot assignment is released by enqueuing a request with a + // null request. The scheduler thread will interpret the null + // request as meaning it should release the sequence slot but + // otherwise do nothing with the request. 
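+      // Sketch of the in-band convention, assuming the slot-release handling
+      // shown in DirectSequenceBatch/OldestSequenceBatch:
+      //
+      //   std::unique_ptr<InferenceRequest> null_request;  // intentionally empty
+      //   batchers_[batcher_idx]->Enqueue(seq_slot, corr_id, null_request);
+      //
+      // The batcher thread treats queue.front() == nullptr as "this sequence
+      // timed out: release the slot via ReleaseSequenceSlot() and continue".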
+ std::unique_ptr null_request; + batchers_[batcher_idx]->Enqueue( + seq_slot, idle_correlation_id, null_request); + } + + // Wait until the next idle timeout needs to be checked + if (wait_microseconds > 0) { + std::unique_lock lock(mu_); + LOG_VERBOSE(2) << "Reaper: sleeping for " << wait_microseconds << "us..."; + std::chrono::microseconds wait_timeout(wait_microseconds); + reaper_cv_.wait_for(lock, wait_timeout); + } + } + + LOG_VERBOSE(1) << "Stopping sequence-batch reaper thread..."; +} + +SequenceBatch::SequenceBatch( + SequenceBatchScheduler* base, const uint32_t batcher_idx, + const size_t seq_slot_cnt, + const std::unordered_map& enforce_equal_shape_tensors, + const bool has_optional_input, + const std::shared_ptr& + start_input_overrides, + const std::shared_ptr& + end_input_overrides, + const std::shared_ptr& + startend_input_overrides, + const std::shared_ptr& + continue_input_overrides, + const std::shared_ptr& + notready_input_overrides) + : base_(base), batcher_idx_(batcher_idx), seq_slot_cnt_(seq_slot_cnt), + enforce_equal_shape_tensors_(enforce_equal_shape_tensors), + has_optional_input_(has_optional_input), + start_input_overrides_(start_input_overrides), + end_input_overrides_(end_input_overrides), + startend_input_overrides_(startend_input_overrides), + continue_input_overrides_(continue_input_overrides), + notready_input_overrides_(notready_input_overrides), + sequence_states_(seq_slot_cnt) +{ +} + +bool +SequenceBatch::CreateCorrelationIDControl(const inference::ModelConfig& config) +{ + // If model wants CORRID control then get the name of the input + // tensor and initialize the override structure for each sequence + // slot that is used to communicate the correlation ID. + std::string correlation_id_tensor_name; + inference::DataType correlation_id_datatype; + Status corrid_status = GetTypedSequenceControlProperties( + config.sequence_batching(), config.name(), + inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_CORRID, + false /* required */, &correlation_id_tensor_name, + &correlation_id_datatype); + if (!corrid_status.IsOk()) { + LOG_ERROR << "failed validating CORRID control for sequence-batch " + "scheduler thread " + << batcher_idx_ << ": " << corrid_status.Message(); + return false; + } + + if (!correlation_id_tensor_name.empty()) { + if ((correlation_id_datatype != inference::DataType::TYPE_UINT64) && + (correlation_id_datatype != inference::DataType::TYPE_INT64) && + (correlation_id_datatype != inference::DataType::TYPE_UINT32) && + (correlation_id_datatype != inference::DataType::TYPE_INT32) && + (correlation_id_datatype != inference::DataType::TYPE_STRING)) { + LOG_ERROR << "unexpected control data type, expected TYPE_UINT64, " + "TYPE_INT64, TYPE_UINT32, TYPE_INT32, or TYPE_STRING for " + << inference::ModelSequenceBatching_Control_Kind_Name( + inference::ModelSequenceBatching::Control:: + CONTROL_SEQUENCE_CORRID) + << " for " << config.name(); + return false; + } + + const std::vector tensor_shape{1}; + std::vector tensor_shape_with_batch_dim{1}; + if (config.max_batch_size() != 0) { + tensor_shape_with_batch_dim.push_back(1); + } + + auto override = std::make_shared( + correlation_id_tensor_name, correlation_id_datatype, tensor_shape); + *override->MutableShape() = override->OriginalShape(); + *override->MutableShapeWithBatchDim() = tensor_shape_with_batch_dim; + + seq_slot_corrid_override_ = std::move(override); + } + + return true; +} + +void +SequenceBatch::SetControlTensors( + std::unique_ptr& irequest, const int32_t seq_slot, + const 
InferenceRequest::SequenceId& corrid, const bool not_ready) +{ + const SequenceBatchScheduler::ControlInputs* controls; + + // Set the start, end, and ready control tensors appropriately... + if (not_ready) { + controls = notready_input_overrides_.get(); + } else if ( + (irequest->Flags() & (TRITONSERVER_REQUEST_FLAG_SEQUENCE_START | + TRITONSERVER_REQUEST_FLAG_SEQUENCE_END)) == + (TRITONSERVER_REQUEST_FLAG_SEQUENCE_START | + TRITONSERVER_REQUEST_FLAG_SEQUENCE_END)) { + controls = startend_input_overrides_.get(); + } else if ( + (irequest->Flags() & TRITONSERVER_REQUEST_FLAG_SEQUENCE_START) != 0) { + controls = start_input_overrides_.get(); + } else if ( + (irequest->Flags() & TRITONSERVER_REQUEST_FLAG_SEQUENCE_END) != 0) { + controls = end_input_overrides_.get(); + } else { + controls = continue_input_overrides_.get(); + } + + for (const auto& control : *controls) { + irequest->AddOverrideInput(control); + } + + // Set correlation ID control tensor if requested by the model. + if (seq_slot_corrid_override_ != nullptr) { + auto& seq_corr_id = seq_slot_corrid_override_; + size_t size_p = triton::common::GetDataTypeByteSize(seq_corr_id->DType()); + if (seq_corr_id->DType() == inference::DataType::TYPE_STRING) { + // 4 bytes for length of string plus pre-defined max string correlation id + // length in bytes + size_p = 4 + triton::core::STRING_CORRELATION_ID_MAX_LENGTH_BYTES; + } + + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + auto corrid_p = + std::make_shared(size_p, TRITONSERVER_MEMORY_CPU, 0); + char* corrid_p_ptr = corrid_p->MutableBuffer(&memory_type, &memory_type_id); + if ((corrid_p_ptr == nullptr) || + ((memory_type != TRITONSERVER_MEMORY_CPU) && + (memory_type != TRITONSERVER_MEMORY_CPU_PINNED)) || + (memory_type_id != 0)) { + LOG_ERROR << "failed to allocate sequence CORRID control signal in CPU " + "memory"; + return; + } + + auto override = std::make_shared( + seq_corr_id->Name(), seq_corr_id->DType(), seq_corr_id->Shape()); + *override->MutableShape() = override->OriginalShape(); + *override->MutableShapeWithBatchDim() = seq_corr_id->ShapeWithBatchDim(); + Status corrid_status = override->SetData(corrid_p); + if (!corrid_status.IsOk()) { + LOG_ERROR << "failed creating CORRID control for sequence-batch " + "scheduler thread " + << batcher_idx_ << " for " << seq_corr_id->Name(); + return; + } + + if (corrid.Type() == InferenceRequest::SequenceId::DataType::STRING) { + std::string correlation_id = corrid.StringValue(); + uint32_t correlation_id_length = correlation_id.length(); + memcpy(corrid_p_ptr, &correlation_id_length, sizeof(uint32_t)); + memcpy( + corrid_p_ptr + sizeof(uint32_t), correlation_id.c_str(), + correlation_id_length); + } else if ( + corrid.Type() == InferenceRequest::SequenceId::DataType::UINT64) { + uint64_t correlation_id = corrid.UnsignedIntValue(); + const char* corrid_ptr = reinterpret_cast(&correlation_id); + memcpy(corrid_p_ptr, corrid_ptr, size_p); + } + irequest->AddOverrideInput(override); + } +} + +void +SequenceBatch::UpdateImplicitState( + std::unique_ptr& irequest, const int32_t seq_slot) +{ + // This should be executed only if the model has a states section. + if (!base_->StateOutputConfigMap().empty()) { + auto& sequence_states = sequence_states_[seq_slot]; + + // Initialize the input state if the sequence is starting. + if ((irequest->Flags() & TRITONSERVER_REQUEST_FLAG_SEQUENCE_START) != 0) { + sequence_states = nullptr; + } + + // Create the state for the first request in the sequence. 
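+    // Lifecycle of sequence_states_[seq_slot], per request flags (assuming the
+    // model declares a 'state' section):
+    //
+    //   SEQUENCE_START set   -> pointer reset above, fresh SequenceStates built
+    //                           below from the configured initial_state data
+    //   START not set        -> the existing SequenceStates is reused, so the
+    //                           output state of the previous step becomes the
+    //                           input state of this one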
+ if (sequence_states == nullptr) { + sequence_states.reset(new SequenceStates); + sequence_states->Initialize( + base_->StateOutputConfigMap(), base_->MaxBatchSize(), + base_->InitialState()); + } + + irequest->SetSequenceStates(sequence_states); + } +} + +DirectSequenceBatch::DirectSequenceBatch( + SequenceBatchScheduler* base, const uint32_t batcher_idx, + const size_t seq_slot_cnt, TritonModelInstance* model_instance, + const std::unordered_map& enforce_equal_shape_tensors, + const bool has_optional_input, + const std::shared_ptr& + start_input_overrides, + const std::shared_ptr& + end_input_overrides, + const std::shared_ptr& + startend_input_overrides, + const std::shared_ptr& + continue_input_overrides, + const std::shared_ptr& + notready_input_overrides, + bool* is_initialized) + : SequenceBatch( + base, batcher_idx, seq_slot_cnt, enforce_equal_shape_tensors, + has_optional_input, start_input_overrides, end_input_overrides, + startend_input_overrides, continue_input_overrides, + notready_input_overrides), + model_instance_(model_instance), scheduler_thread_exit_(false), + scheduler_idle_(false), queues_(seq_slot_cnt), + seq_slot_correlation_ids_(seq_slot_cnt, 0), max_active_seq_slot_(-1) +{ + // Initialize to handle CORRID control. If error just exit + // now... that means the corresponding model instance will not have + // any runner and so will not get used for execution. + const auto& config = model_instance_->Model()->Config(); + if (!CreateCorrelationIDControl(config)) { + *is_initialized = false; + return; + } + + max_batch_size_ = ((size_t)std::max(1, config.max_batch_size())); + minimum_slot_utilization_ = + config.sequence_batching().direct().minimum_slot_utilization(); + pending_batch_delay_ns_ = + config.sequence_batching().direct().max_queue_delay_microseconds() * 1000; + + // Create a scheduler thread associated with 'batcher_idx' that + // executes the queued requests. + const int nice = 0; + NewPayload(); + scheduler_thread_.reset( + new std::thread([this, nice]() { BatcherThread(nice); })); + + *is_initialized = true; +} + +DirectSequenceBatch::~DirectSequenceBatch() +{ + // Signal the scheduler thread to exit... + { + std::unique_lock lock(mu_); + scheduler_thread_exit_ = true; + } + + cv_.notify_one(); + + // It is possible for the scheduler thread to be the last holder of + // a model object, and when that scheduler thread releases the + // object the scheduler thread itself will destroy this + // SequenceBatch object. So we need to check to make sure the + // scheduler thread does not join it against itself and instead + // detach it so there is not a problem when its thread object is + // destroyed. + if (scheduler_thread_->joinable()) { + scheduler_thread_->join(); + } +} + +void +DirectSequenceBatch::Enqueue( + const uint32_t seq_slot, const InferenceRequest::SequenceId& correlation_id, + std::unique_ptr& request) +{ + bool wake_runner = false; + + { + std::lock_guard lock(mu_); + + queues_[seq_slot].emplace_back(std::move(request)); + + seq_slot_correlation_ids_[seq_slot] = correlation_id; + max_active_seq_slot_ = + std::max(max_active_seq_slot_, static_cast(seq_slot)); + + // If runner is idle then wake it to service this request. 
We do + // the actual wake outside of the lock to avoid having the woken + // thread immediately block on the lock + wake_runner = scheduler_idle_; + } + + if (wake_runner) { + cv_.notify_one(); + } +} + +void +DirectSequenceBatch::NewPayload() +{ + curr_payload_ = + model_instance_->Model()->Server()->GetRateLimiter()->GetPayload( + Payload::Operation::INFER_RUN, model_instance_); +} + +void +DirectSequenceBatch::BatcherThread(const int nice) +{ +#ifndef _WIN32 + if (setpriority(PRIO_PROCESS, syscall(SYS_gettid), nice) == 0) { + LOG_VERBOSE(1) << "Starting Direct sequence-batch scheduler thread " + << batcher_idx_ << " at nice " << nice << "..."; + } else { + LOG_VERBOSE(1) << "Starting Direct sequence-batch scheduler thread " + << batcher_idx_ << " at default nice (requested nice " + << nice << " failed)..."; + } +#else + LOG_VERBOSE(1) << "Starting Direct sequence-batch scheduler thread " + << batcher_idx_ << " at default nice..."; +#endif + + // For debugging and testing, delay start of thread until queues + // contain the specified number of entries (across all + // SequenceBatchs in the scheduler). + const char* dstr = getenv("TRITONSERVER_DELAY_SCHEDULER"); + size_t delay_cnt = 0; + if (dstr != nullptr) { + delay_cnt = atoi(dstr); + LOG_VERBOSE(1) << "Delaying scheduler thread " << batcher_idx_ << " until " + << delay_cnt << " queued requests..."; + } + + const uint64_t default_wait_microseconds = 500 * 1000; + exec_complete_ = true; + + // When there is optional input or input shape must be enforced, + // the inputs in the requests must be examined for forming a batch + const bool check_input = + !enforce_equal_shape_tensors_.empty() || has_optional_input_; + while (!scheduler_thread_exit_) { + uint64_t wait_microseconds = default_wait_microseconds; + + // Wait till execution of the last enqueued payload is + // complete. + { + std::unique_lock lk(payload_mu_); + payload_cv_.wait(lk, [this] { return exec_complete_; }); + } + + // Hold the lock for as short a time as possible. + { + std::unique_lock lock(mu_); + + if (delay_cnt > 0) { + wait_microseconds = 10 * 1000; + // Debugging/testing... wait until queues together contain at + // least 'delay_cnt' items... + size_t total_size = 0; + for (const auto& q : queues_) { + total_size += q.size(); + } + if (!base_->DelayScheduler(batcher_idx_, total_size, delay_cnt)) { + delay_cnt = 0; + } + LOG_VERBOSE(1) << "Delaying scheduler thread " << batcher_idx_ + << " until " << delay_cnt + << " queued requests, current total = " << total_size; + } else { + RequiredEqualInputs required_equal_inputs; + InferenceRequest* null_irequest = nullptr; + + // Make one pass through the active slots to: + // + // 1) release any slots that have forcibly ended sequences + // + // 2) find a representative request that will provide: + // + // a) the shape, type, etc. information for null requests + // + // b) the required tensor shapes for the batch for the + // case where ragged batching is not allowed + // + // 3) Determine the earliest enqueue time and number of ready + // sequences if queue delay is enabled + // + int32_t max_seq_slot = -1; + uint64_t earliest_enqueue_time_ns = UINT64_MAX; + size_t ready_cnt = 0; + for (int32_t seq_slot = 0; seq_slot <= max_active_seq_slot_; + ++seq_slot) { + std::deque>& queue = + queues_[seq_slot]; + if (!queue.empty()) { + // If the request is nullptr then the sequence in the slot + // has timed-out so release the slot for another sequence + // from the backlog. 
+ if (queue.front() == nullptr) { + queue.pop_front(); + + SequenceBatchScheduler::BatcherSequenceSlot batcher_seq_slot( + batcher_idx_, seq_slot); + seq_slot_correlation_ids_[seq_slot] = + base_->ReleaseSequenceSlot(batcher_seq_slot, &queue); + } + } + + // Need to check queue again for contents since if released + // above it may now be empty... + if (!queue.empty()) { + // For NULL requests need an InferenceRequest that can be + // batched but has controls set to "not ready". Any + // request can serve this purpose so grab a copy of the + // first one. This first request is also used to + // initialize 'required_equal_inputs' so we are sure that + // this null request will have the correct shape for any + // created batch. + if (null_irequest == nullptr) { + null_irequest = queue.front().get(); + UpdateImplicitState(queue.front(), seq_slot); + } + + // If this is the first non-null request capture the shape + // of the tensors that don't support ragged so we can + // compare them to later requests. + if (!required_equal_inputs.Initialized() && check_input) { + Status status = required_equal_inputs.Initialize( + queue.front(), enforce_equal_shape_tensors_, + has_optional_input_); + if (!status.IsOk()) { + LOG_ERROR + << "internal: unexpecting failure initializing shape: " + << status.Message(); + } + } + + earliest_enqueue_time_ns = std::min( + earliest_enqueue_time_ns, queue.front()->BatcherStartNs()); + ready_cnt++; + max_seq_slot = seq_slot; + } + } + + if (max_seq_slot != -1) { + if ((pending_batch_delay_ns_ == 0) || + (minimum_slot_utilization_ == 0.0)) { + wait_microseconds = 0; + } else { + // Compare the age of the oldest pending request to the maximum + // batch queuing delay, and the size of the ready requests in the + // batch, execute now if queuing delay is exceeded or the batch + // size is large enough. Otherwise create a timer to wakeup a + // thread to check again at the maximum allowed delay. + uint64_t now_ns = + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint64_t current_batch_delay_ns = + (now_ns - earliest_enqueue_time_ns); + if ((current_batch_delay_ns > pending_batch_delay_ns_) || + (((float)ready_cnt) / max_batch_size_ >= + minimum_slot_utilization_)) { + wait_microseconds = 0; + LOG_VERBOSE(1) + << "start sequence batch execution. " + << "current batch delay: " << current_batch_delay_ns + << "; maximum delay allowed: " << pending_batch_delay_ns_ + << "slot utilization: " << ready_cnt << "/" << max_batch_size_ + << "; utilization threshold: " << minimum_slot_utilization_; + } else { + wait_microseconds = + (pending_batch_delay_ns_ - current_batch_delay_ns) / 1000; + // reset 'max_seq_slot' so that not request is pulled from the + // queues + max_seq_slot = -1; + LOG_VERBOSE(1) + << "defer sequence batch execution. " + << "current batch delay: " << current_batch_delay_ns + << "; maximum delay allowed: " << pending_batch_delay_ns_ + << "slot utilization: " << ready_cnt << "/" << max_batch_size_ + << "; utilization threshold: " << minimum_slot_utilization_; + } + } + } + + // Collect requests from slot 0 to max_seq_slot. + for (int32_t seq_slot = 0; seq_slot <= max_seq_slot; ++seq_slot) { + bool end_of_sequence = false; + bool use_null_request = false; + std::deque>& queue = + queues_[seq_slot]; + + // If 'seq_slot' doesn't have any requests then change the + // request to send dummy/null input tensors for this + // slot. We need this so that other requests stay in the + // correct slot. 
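+          // Hypothetical example with max_batch_size = 4 and ready requests
+          // only in slots 0 and 2 (max_seq_slot == 2):
+          //
+          //   batch index:   0       1       2
+          //   payload:       req A   NULL    req B
+          //
+          // The NULL entry is a copy of a real request with
+          // CONTROL_SEQUENCE_READY == false, so the model ignores it while
+          // req B stays at batch index 2, the slot its sequence owns.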
+ if (queue.empty()) { + use_null_request = true; + } + // If there are one or more tensors that don't support + // ragged batch, then don't allow a request into an existing + // batch if shape differs. + else if (required_equal_inputs.Initialized() && check_input) { + if (!required_equal_inputs.HasEqualInputs(queue.front())) { + use_null_request = true; + } + } + + // Use null-request if necessary otherwise use the next + // request in the queue... + if (use_null_request) { + std::unique_ptr ni( + InferenceRequest::CopyAsNull(*null_irequest)); + // Note that when the not-ready control input of the + // request is "true" the model can't assume that any + // other inputs are meaningful, including CORRID. So we + // just use zero for that. + SetControlTensors( + ni, seq_slot, 0 /* corrid */, true /* not_ready */); + + // This should be executed only if the model has a states section. + if (!base_->StateOutputConfigMap().empty()) { + // For NULL requests we will be using a dummy state instead of the + // real state stored in Triton. When the model is using variable + // dimensions and batching, the null request's input state shapes + // may be different from the actual shapes of the state for that + // sequence. We create a dummy state in order to avoid corrupting + // the actual state of the sequence. + std::shared_ptr sequence_states( + new SequenceStates); + sequence_states->SetNullSequenceStates( + null_irequest->GetSequenceStates()); + ni->SetSequenceStates(sequence_states); + } + + curr_payload_->AddRequest(std::move(ni)); + } else { + std::unique_ptr& irequest = queue.front(); + + // Set the control tensor values in the request. + SetControlTensors(irequest, seq_slot, irequest->CorrelationId()); + + // Update the implicit state and set the input state tensors. + UpdateImplicitState(irequest, seq_slot); + + if ((irequest->Flags() & TRITONSERVER_REQUEST_FLAG_SEQUENCE_END) != + 0) { + end_of_sequence = true; + } + curr_payload_->AddRequest(std::move(irequest)); + + queue.pop_front(); + } + + if (curr_payload_->GetState() == Payload::State::UNINITIALIZED) { + curr_payload_->SetState(Payload::State::READY); + } + + // If the sequence has ended then attempt to refill the + // sequence slot with a sequence from the backlog. If + // there is no backlog show that the slot is no longer + // active. + if (end_of_sequence) { + LOG_VERBOSE(1) << "End sequence CORRID " + << seq_slot_correlation_ids_[seq_slot] + << " in batcher " << batcher_idx_ << ", slot " + << seq_slot; + + // Should never be anything in a queue after the END + // marker. If it happens that means we will clobber + // that request if/when we swap in a backlog sequence + // in ReleaseSequenceSlot below. + if (!queue.empty()) { + LOG_ERROR << "internal: unexpected requests after sequence " + "end in slot " + << seq_slot; + } + + SequenceBatchScheduler::BatcherSequenceSlot batcher_seq_slot( + batcher_idx_, seq_slot); + seq_slot_correlation_ids_[seq_slot] = + base_->ReleaseSequenceSlot(batcher_seq_slot, &queue); + } + } + } + + // One or more sequences may have ended... find the new + // 'max_active_seq_slot_'. + while ((max_active_seq_slot_ >= 0) && + (!seq_slot_correlation_ids_[max_active_seq_slot_].InSequence())) { + max_active_seq_slot_--; + } + + // If no requests are to be handled, wait for notification or + // for the specified timeout before checking the queues again. 
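+      // Values wait_microseconds can hold at this point:
+      //
+      //   0                        -> a batch was formed; fall through and submit
+      //   remaining delay / 1000   -> queue-delay / slot-utilization target not
+      //                               yet met; sleep until the allowed delay ends
+      //   500000 (default)         -> nothing ready; idle until Enqueue() notifies
+      //   10000                    -> only while the TRITONSERVER_DELAY_SCHEDULER
+      //                               debug gate is still holding requests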
+ if (wait_microseconds > 0) { + scheduler_idle_ = true; + std::chrono::microseconds wait_timeout(wait_microseconds); + cv_.wait_for(lock, wait_timeout); + scheduler_idle_ = false; + } + } + + if (curr_payload_->GetState() == Payload::State::READY) { + // Add callback to signal the execution completion + exec_complete_ = false; + auto callback = [this]() { + { + std::unique_lock lk(payload_mu_); + exec_complete_ = true; + } + payload_cv_.notify_one(); + }; + curr_payload_->AddInternalReleaseCallback(callback); + curr_payload_->MarkSaturated(); + + // Enqueue the payload to RateLimiter + model_instance_->Model()->Server()->GetRateLimiter()->EnqueuePayload( + model_instance_->Model(), curr_payload_); + NewPayload(); + } + } // end runner loop + + LOG_VERBOSE(1) << "Stopping Direct sequence-batch scheduler thread " + << batcher_idx_ << "..."; +} + +OldestSequenceBatch::OldestSequenceBatch( + SequenceBatchScheduler* base, const uint32_t batcher_idx, + const size_t seq_slot_cnt, TritonModelInstance* model_instance, + const std::unordered_map& enforce_equal_shape_tensors, + const bool has_optional_input, + const std::shared_ptr& + start_input_overrides, + const std::shared_ptr& + end_input_overrides, + const std::shared_ptr& + startend_input_overrides, + const std::shared_ptr& + continue_input_overrides, + const std::shared_ptr& + notready_input_overrides, + bool* is_initialized) + : SequenceBatch( + base, batcher_idx, seq_slot_cnt, enforce_equal_shape_tensors, + has_optional_input, start_input_overrides, end_input_overrides, + startend_input_overrides, continue_input_overrides, + notready_input_overrides), + in_flight_(seq_slot_cnt, false), queues_(seq_slot_cnt) +{ + // Initialize to handle CORRID control. If error just exit + // now... that means the corresponding model instance will not have + // any runner and so will not get used for execution. + const auto& config = model_instance->Model()->Config(); + if (!CreateCorrelationIDControl(config)) { + *is_initialized = false; + return; + } + + // Create a dynamic batcher use to batch together sequences for + // inference. + std::set preferred_batch_sizes; + for (const auto size : + config.sequence_batching().oldest().preferred_batch_size()) { + preferred_batch_sizes.insert(size); + } + + // TODO: Provide appropriate request_cache_enable flag when caching + // is enabled for sequence models. + Status status = DynamicBatchScheduler::Create( + model_instance->Model(), model_instance, + triton::common::GetCpuNiceLevel(config), + true /* dynamic_batching_enabled */, config.max_batch_size(), + enforce_equal_shape_tensors_, true /* preserve_ordering */, + false /* response_cache_enable */, preferred_batch_sizes, + config.sequence_batching().oldest().max_queue_delay_microseconds(), + &dynamic_batcher_); + if (!status.IsOk()) { + LOG_ERROR << "failed creating dynamic sequence batcher for OldestFirst " + << batcher_idx_ << ": " << status.Message(); + *is_initialized = false; + return; + } + + *is_initialized = true; +} +OldestSequenceBatch::~OldestSequenceBatch() {} + +void +OldestSequenceBatch::CompleteAndNext(const uint32_t seq_slot) +{ + std::lock_guard lock(mu_); + + // We may enqueue 1 or more pending inferences triggered by the + // completion. If the sequence has a pending inference then it needs + // to be send to dynamic batcher since the "previous" inference just + // completed. 
If this next inference ends up being the end of the + // sequence (either from the END flag or because the sequence is + // being force-ended) then we try to fill the now-free sequence slot + // from the backlog and then send the first inference from that + // sequence to the dynamic batcher... + std::deque>& queue = queues_[seq_slot]; + bool retry = true; + while (retry) { + retry = false; + + bool release_seq_slot = false; + in_flight_[seq_slot] = false; + + // If the next sequence inference is ready in the queue then enqueue + // it in the dynamic batcher now. + if (!queue.empty()) { + auto& irequest = queue.front(); + + // If the request is null then this inference request is from + // the reaper thread indicating a timed-out sequence. Mark that + // the sequence slot should be released but otherwise do + // nothing. + if (irequest == nullptr) { + LOG_VERBOSE(1) << irequest->LogRequest() + << "force-end sequence in batcher " << batcher_idx_ + << ", slot " << seq_slot; + release_seq_slot = true; + } else { + const InferenceRequest::SequenceId& correlation_id = + irequest->CorrelationId(); + + // After handling the last inference in a sequence we must + // release the sequence slot to make it available to another + // sequence. + if ((irequest->Flags() & TRITONSERVER_REQUEST_FLAG_SEQUENCE_END) != 0) { + LOG_VERBOSE(1) << irequest->LogRequest() << "end sequence CORRID " + << correlation_id << " in batcher " << batcher_idx_ + << ", slot " << seq_slot; + release_seq_slot = true; + } + + // Add the appropriate control tensor values to the request. + SetControlTensors(irequest, seq_slot, correlation_id); + + // Update the implicit state and set the input state tensors. + UpdateImplicitState(irequest, seq_slot); + + LOG_VERBOSE(1) << irequest->LogRequest() + << "issue to dynamic batcher CORRID " << correlation_id + << " in batcher " << batcher_idx_ << ", slot " + << seq_slot; + in_flight_[seq_slot] = true; + + irequest->AddInternalReleaseCallback( + [this, seq_slot]() { CompleteAndNext(seq_slot); }); + + dynamic_batcher_->Enqueue(irequest); + } + + queue.pop_front(); + } + + // If releasing the sequence slot then the sequence queue should be + // empty and we can now assign a new sequence to the queue (from the + // backlog). + if (release_seq_slot) { + // Should never be anything in a queue after the END marker. If it + // happens that means we will clobber that request if/when we swap + // in a backlog sequence in ReleaseSequenceSlot below. + if (!queue.empty()) { + LOG_ERROR << "internal: unexpected requests after sequence end in slot " + << seq_slot; + } + + SequenceBatchScheduler::BatcherSequenceSlot batcher_seq_slot( + batcher_idx_, seq_slot); + const InferenceRequest::SequenceId& released_cid = + base_->ReleaseSequenceSlot(batcher_seq_slot, &queue); + + if (released_cid.InSequence()) { + LOG_VERBOSE(1) << "Enqueued new sequence containing " << queue.size() + << " requests into OldestFirst batcher " << batcher_idx_ + << ", slot " << seq_slot; + + // If an inference is already in-flight in the dynamic batcher + // in this sequence slot then can't process the new queue + // inferences right now, because the in-flight request is + // using slot resources like the CORRID override map. + if (!in_flight_[seq_slot]) { + retry = true; + } + } + } + } +} + +void +OldestSequenceBatch::Enqueue( + const uint32_t seq_slot, const InferenceRequest::SequenceId& correlation_id, + std::unique_ptr& request) +{ + // Queue the new request... 
if there isn't already a request in + // flight for this sequence then send one to the dynamic batcher + // immediately. + bool in_flight; + { + std::lock_guard lock(mu_); + + std::deque>& queue = queues_[seq_slot]; + queue.emplace_back(std::move(request)); + in_flight = in_flight_[seq_slot]; + } + + if (!in_flight) { + CompleteAndNext(seq_slot); + } +} +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/sequence_batch_scheduler.h b/3rdparty/core-r22.12/src/sequence_batch_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..44b7594717b2610917c78e24eb5522a83077b954 --- /dev/null +++ b/3rdparty/core-r22.12/src/sequence_batch_scheduler.h @@ -0,0 +1,399 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend_model.h" +#include "backend_model_instance.h" +#include "model_config.pb.h" +#include "rate_limiter.h" +#include "scheduler.h" +#include "scheduler_utils.h" +#include "sequence_state.h" +#include "status.h" +#include "triton/common/model_config.h" + +namespace triton { namespace core { + +class SequenceBatch; + +// Scheduler that implements batching across sequences of correlated +// inferences. +class SequenceBatchScheduler : public Scheduler { + public: + using ControlInputs = std::vector>; + + SequenceBatchScheduler() = default; + ~SequenceBatchScheduler(); + + // Create a scheduler to support a given number of runners and a run + // function to call when a request is scheduled. 
+ static Status Create( + TritonModel* model, + const std::unordered_map& enforce_equal_shape_tensors, + std::unique_ptr* scheduler); + + // \see Scheduler::Enqueue() + Status Enqueue(std::unique_ptr& request) override; + + // \see Scheduler::InflightInferenceCount() + size_t InflightInferenceCount() override + { + std::unique_lock lock(mu_); + return sequence_to_batcherseqslot_map_.size(); + } + + // \see Scheduler::Stop() + void Stop() override { stop_ = true; } + + // A batcher-sequence_slot combination. The batcher is represented + // by the index into 'batchers_'. + struct BatcherSequenceSlot { + BatcherSequenceSlot() = default; + BatcherSequenceSlot(const BatcherSequenceSlot&) = default; + BatcherSequenceSlot(size_t b, uint32_t s) : batcher_idx_(b), seq_slot_(s) {} + size_t batcher_idx_; + uint32_t seq_slot_; + }; + + // Fill a sequence slot with a sequence from the backlog or show + // that the sequence slot is no longer being used. + InferenceRequest::SequenceId ReleaseSequenceSlot( + const BatcherSequenceSlot& seq_slot, + std::deque>* requests); + + // For debugging/testing, batcher reports how many waiting requests + // and returns true if the batcher should continue waiting. + bool DelayScheduler( + const uint32_t batcher_idx, const size_t cnt, const size_t total); + + const std::unordered_map< + std::string, const inference::ModelSequenceBatching_State&>& + StateOutputConfigMap() + { + return state_output_config_map_; + } + + size_t MaxBatchSize() { return max_batch_size_; } + const std::unordered_map& + InitialState() + { + return initial_state_; + } + + private: + void ReaperThread(const int nice); + + Status CreateBooleanControlTensors( + const inference::ModelConfig& config, + std::shared_ptr* start_input_overrides, + std::shared_ptr* end_input_overrides, + std::shared_ptr* startend_input_overrides, + std::shared_ptr* continue_input_overrides, + std::shared_ptr* notready_input_overrides); + + Status GenerateInitialStateData( + const inference::ModelSequenceBatching_InitialState& initial_state, + const inference::ModelSequenceBatching_State& state, TritonModel* model); + + struct BatcherSequenceSlotCompare { + bool operator()( + const BatcherSequenceSlot& a, const BatcherSequenceSlot& b) const + { + return a.seq_slot_ > b.seq_slot_; + } + }; + + // The max_sequence_idle_microseconds value for this scheduler. + uint64_t max_sequence_idle_microseconds_; + + bool stop_; + + // Mutex + std::mutex mu_; + + // The reaper thread + std::unique_ptr reaper_thread_; + std::condition_variable reaper_cv_; + bool reaper_thread_exit_; + + // The SequenceBatchs being managed by this scheduler. + std::vector> batchers_; + + // Map from a request's correlation ID to the BatcherSequenceSlot + // assigned to that correlation ID. + using BatcherSequenceSlotMap = + std::unordered_map; + BatcherSequenceSlotMap sequence_to_batcherseqslot_map_; + + // Map from a request's correlation ID to the backlog queue + // collecting requests for that correlation ID. + using BacklogMap = std::unordered_map< + InferenceRequest::SequenceId, + std::shared_ptr>>>; + BacklogMap sequence_to_backlog_map_; + + // The ordered backlog of sequences waiting for a free sequenceslot. + std::deque>>> + backlog_queues_; + + // The batcher/sequence-slot locations ready to accept a new + // sequence. Ordered from lowest sequence-slot-number to highest so + // that all batchers grow at the same rate and attempt to remain as + // small as possible. 
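+  // For example, given the greater-than comparator above, std::priority_queue
+  // (a max-heap by default) behaves as a min-heap on the slot number:
+  //
+  //   push {batcher 0, slot 3}, {batcher 1, slot 0}, {batcher 0, slot 1};
+  //   top() -> {batcher 1, slot 0}   // lowest slot number is handed out first
+  //
+  // so freed low-numbered slots are reused before higher ones and each
+  // batcher's active slots stay packed toward index 0.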
+ std::priority_queue< + BatcherSequenceSlot, std::vector, + BatcherSequenceSlotCompare> + ready_batcher_seq_slots_; + + // For each correlation ID the most recently seen timestamp, in + // microseconds, for a request using that correlation ID. + std::unordered_map + correlation_id_timestamps_; + + // Used for debugging/testing. + size_t backlog_delay_cnt_; + std::vector queue_request_cnts_; + + // IO mapping between the output state name and the state configuration. + std::unordered_map + state_output_config_map_; + size_t max_batch_size_; + + // Initial state used for implicit state. + std::unordered_map + initial_state_; +}; + +// Base class for a scheduler that implements a particular scheduling +// strategy for a model instance. +class SequenceBatch { + public: + SequenceBatch( + SequenceBatchScheduler* base, const uint32_t batcher_idx, + const size_t seq_slot_cnt, + const std::unordered_map& enforce_equal_shape_tensors, + const bool has_optional_input, + const std::shared_ptr& + start_input_overrides, + const std::shared_ptr& + end_input_overrides, + const std::shared_ptr& + startend_input_overrides, + const std::shared_ptr& + continue_input_overrides, + const std::shared_ptr& + notready_input_overrides); + virtual ~SequenceBatch() = default; + + // Enqueue a request into the appropriate queue for the requested + // sequence slot. This function takes ownership of 'request' so on + // request 'request' will be nullptr. + virtual void Enqueue( + const uint32_t seq_slot, + const InferenceRequest::SequenceId& correlation_id, + std::unique_ptr& request) = 0; + + protected: + bool CreateCorrelationIDControl(const inference::ModelConfig& config); + void SetControlTensors( + std::unique_ptr& irequest, const int32_t seq_slot, + const InferenceRequest::SequenceId& corr_id, + const bool not_ready = false); + + // Update the implicit state and set the required input states. + void UpdateImplicitState( + std::unique_ptr& irequest, const int32_t seq_slot); + + // The controlling scheduler. + SequenceBatchScheduler* const base_; + + // The index of this batcher within the controlling scheduler. + const uint32_t batcher_idx_; + + // The number of candidate sequence slots. + const size_t seq_slot_cnt_; + + // The input tensors that require shape checking before being + // allowed in a batch. As a map from the tensor name to a bool. If + // tensor is in map then its shape must match shape of same tensor + // in requests already in the batch. If value is "true" then + // additional tensor is treated as a shape tensor and the values + // contained in the shape tensor must match same tensor already in + // the batch. + const std::unordered_map enforce_equal_shape_tensors_; + + // Store information on whether the model contains optional inputs. + bool has_optional_input_; + + // The control values, delivered as input tensors, that should be + // used when starting a sequence, continuing a sequence, ending a + // sequence, and showing that a sequence has not input available. + std::shared_ptr start_input_overrides_; + std::shared_ptr end_input_overrides_; + std::shared_ptr + startend_input_overrides_; + std::shared_ptr + continue_input_overrides_; + std::shared_ptr + notready_input_overrides_; + + // The correlation ID override. Empty if model does not specify the + // CONTROL_SEQUENCE_CORRID control. + std::shared_ptr seq_slot_corrid_override_; + + // For each sequence slot store the optional state i/o tensors. 
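  // Editorial sketch (illustrative, not upstream code): when a new sequence
  // claims a slot, the per-slot state is typically (re)built from the
  // scheduler's state configuration, for example
  //
  //   auto states = std::make_shared<SequenceStates>();
  //   RETURN_IF_ERROR(states->Initialize(
  //       base_->StateOutputConfigMap(), base_->MaxBatchSize(),
  //       base_->InitialState()));
  //   sequence_states_[seq_slot] = std::move(states);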
+ std::vector> sequence_states_; +}; + +// Scheduler that implements the Direct sequence scheduling strategy +// for a model instance. +class DirectSequenceBatch : public SequenceBatch { + public: + DirectSequenceBatch( + SequenceBatchScheduler* base, const uint32_t batcher_idx, + const size_t seq_slot_cnt, TritonModelInstance* model_instance, + const std::unordered_map& enforce_equal_shape_tensors, + const bool has_optional_input, + const std::shared_ptr& + start_input_overrides, + const std::shared_ptr& + end_input_overrides, + const std::shared_ptr& + startend_input_overrides, + const std::shared_ptr& + continue_input_overrides, + const std::shared_ptr& + notready_input_overrides, + bool* is_initialized); + ~DirectSequenceBatch(); + + void Enqueue( + const uint32_t seq_slot, + const InferenceRequest::SequenceId& correlation_id, + std::unique_ptr& request) override; + + private: + void BatcherThread(const int nice); + void NewPayload(); + + std::shared_ptr curr_payload_; + TritonModelInstance* model_instance_; + + // The thread scheduling requests that are queued in this batch. + std::unique_ptr scheduler_thread_; + bool scheduler_thread_exit_; + bool scheduler_idle_; + + // Mutex protecting correlation queues, etc. + std::mutex mu_; + std::condition_variable cv_; + + // Execution state of the last enqueued payload + bool exec_complete_; + + // Mutex protecting execution state of payload + std::mutex payload_mu_; + std::condition_variable payload_cv_; + + // Queues holding inference requests. There are 'seq_slot_cnt' + // queues, one for each sequence slot where requests assigned to + // that slot are enqueued to wait for inferencing. + std::vector>> queues_; + + // Is each sequence slot active or not? A zero or empty value indicates + // inactive, a non-zero/non-empty value indicates active and is the + // correlation ID of the sequence active in the slot. An empty + // queue for a sequence slot does not mean it's inactive... it + // could just not have any requests pending at the moment. + std::vector seq_slot_correlation_ids_; + + // The maximum active sequence slot. A value of -1 indicates that + // no slots are active in the model. + int32_t max_active_seq_slot_; + + size_t max_batch_size_; + float minimum_slot_utilization_; + uint64_t pending_batch_delay_ns_; +}; + +// Scheduler that implements the oldest-first sequence scheduling +// strategy for a model instance. +class OldestSequenceBatch : public SequenceBatch { + public: + OldestSequenceBatch( + SequenceBatchScheduler* base, const uint32_t batcher_idx, + const size_t seq_slot_cnt, TritonModelInstance* model_instance, + const std::unordered_map& enforce_equal_shape_tensors, + const bool has_optional_input, + const std::shared_ptr& + start_input_overrides, + const std::shared_ptr& + end_input_overrides, + const std::shared_ptr& + startend_input_overrides, + const std::shared_ptr& + continue_input_overrides, + const std::shared_ptr& + notready_input_overrides, + bool* is_initialized); + ~OldestSequenceBatch(); + + void Enqueue( + const uint32_t seq_slot, + const InferenceRequest::SequenceId& correlation_id, + std::unique_ptr& request) override; + + private: + void CompleteAndNext(const uint32_t seq_slot); + + // The dynamic batcher for this scheduler + std::unique_ptr dynamic_batcher_; + + TritonModelInstance* model_instance_; + + // Mutex protecting queues, etc. + std::mutex mu_; + + // For each sequence slot, true if there is a request for that + // sequence in-flight in the dynamic batcher. 
Used to ensure that at + // most one request from each sequence can be scheduled at a time. + std::vector in_flight_; + + // Queues holding inference requests. There are 'seq_slot_cnt' + // queues, one for each sequence slot where requests assigned to + // that slot are enqueued to wait for inferencing. + std::vector>> queues_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/sequence_state.cc b/3rdparty/core-r22.12/src/sequence_state.cc new file mode 100644 index 0000000000000000000000000000000000000000..af0605cf41c6c96d38971ee0c9189e8913424c8f --- /dev/null +++ b/3rdparty/core-r22.12/src/sequence_state.cc @@ -0,0 +1,336 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
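// Editorial usage sketch (illustrative, not part of the upstream file; the
// state name and shape below are placeholders): a scheduler holding a
// SequenceStates object asks for an output state, fills its buffer, and then
// calls Update() to promote it to the input state for the next request in
// the sequence, roughly
//
//   SequenceState* output_state = nullptr;
//   RETURN_IF_ERROR(states->OutputState(
//       "OUTPUT_STATE", inference::DataType::TYPE_FP32,
//       {1, 1, 4} /* shape */, &output_state));
//   /* ... write into output_state->Data() ... */
//   RETURN_IF_ERROR(output_state->Update());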
+ +#include "sequence_state.h" + +#include "memory.h" +#include "triton/common/logging.h" + +namespace triton { namespace core { + +SequenceState::SequenceState() : data_(new MemoryReference) {} + +SequenceState::SequenceState( + const std::string& name, const inference::DataType datatype, + const int64_t* shape, const uint64_t dim_count) + : name_(name), datatype_(datatype), shape_(shape, shape + dim_count), + data_(new MemoryReference) +{ +} + +SequenceState::SequenceState( + const std::string& name, const inference::DataType datatype, + const std::vector& shape) + : name_(name), datatype_(datatype), shape_(shape), + data_(new MemoryReference) +{ +} + +Status +SequenceState::SetData(const std::shared_ptr& data) +{ + if (data_->TotalByteSize() != 0) { + return Status( + Status::Code::INVALID_ARG, + "state '" + name_ + "' already has data, can't overwrite"); + } + + data_ = data; + return Status::Success; +} + +Status +SequenceState::RemoveAllData() +{ + data_ = std::make_shared(); + return Status::Success; +} + +Status +SequenceState::SetStringDataToZero() +{ + if (Data()->TotalByteSize() % 4 != 0) { + return Status( + Status::Code::INVALID_ARG, + "The total byte size must be a multiple of 4 when setting the " + "sequence state to zero."); + } + + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + + const std::shared_ptr& memory = + reinterpret_cast&>(Data()); + char* buffer = memory->MutableBuffer(&memory_type, &memory_type_id); + memset(buffer, 0, Data()->TotalByteSize()); + + return Status::Success; +} + +Status +SequenceStates::Initialize( + const std::unordered_map< + std::string, const inference::ModelSequenceBatching_State&>& + state_output_config_map, + const size_t max_batch_size, + const std::unordered_map& initial_state) +{ + input_states_.clear(); + output_states_.clear(); + + for (auto& state : state_output_config_map) { + auto& state_config = state.second; + + std::vector dims; + if (max_batch_size != 0) { + dims.push_back(1); + } + + // Convert the variable dimensions to 1 for the first request. + for (auto& dim : state_config.dims()) { + if (dim == -1) { + dims.push_back(1); + } else { + dims.push_back(dim); + } + } + + std::shared_ptr data; + auto initial_state_it = initial_state.find(state_config.input_name()); + if (initial_state_it != initial_state.end()) { + data = std::make_shared( + initial_state_it->second.data_->TotalByteSize(), + TRITONSERVER_MEMORY_CPU, 0); + + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + char* dst_buffer = data->MutableBuffer(&memory_type, &memory_type_id); + char* initial_state_buffer = + initial_state_it->second.data_->MutableBuffer( + &memory_type, &memory_type_id); + + memcpy( + dst_buffer, initial_state_buffer, + initial_state_it->second.data_->TotalByteSize()); + } else { + size_t state_size; + if (state.second.data_type() == inference::DataType::TYPE_STRING) { + auto element_count = triton::common::GetElementCount(dims); + // Total number of bytes required is equal to the element count + // multiplied by 4. 
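        // For example, a TYPE_STRING state whose normalized dims are [1, 8]
        // reserves 8 * 4 = 32 bytes; zeroing that buffer via
        // SetStringDataToZero() yields eight empty strings, since each
        // element is represented by a 4-byte length prefix.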
+ state_size = 4 * element_count; + } else { + state_size = + triton::common::GetByteSize(state.second.data_type(), dims); + } + data = std::make_shared( + state_size, TRITONSERVER_MEMORY_CPU, 0); + } + + const auto& input_pair = input_states_.emplace( + std::piecewise_construct, + std::forward_as_tuple(state_config.input_name()), + std::forward_as_tuple(new SequenceState( + state_config.input_name(), state.second.data_type(), dims))); + + if (!input_pair.second) { + LOG_WARNING + << "Detected duplicate 'input_name' in the state configuration: '" + << state_config.input_name() + << ".' This state configuration will be ignored."; + continue; + } + + auto& input_tensor = input_pair.first->second; + RETURN_IF_ERROR(input_tensor->SetData(data)); + if (input_tensor->DType() == inference::DataType::TYPE_STRING) { + RETURN_IF_ERROR(input_tensor->SetStringDataToZero()); + } + + const auto& output_pair = output_states_.emplace( + std::piecewise_construct, + std::forward_as_tuple(state_config.output_name()), + std::forward_as_tuple()); + if (!output_pair.second) { + // Remove the corresponding state from the input_states_map + input_states_.erase(state_config.input_name()); + LOG_WARNING << "Detected duplicate 'output_name' in the state " + "configuration: '" + << state_config.output_name() + << "'. This state configuration will be ignored."; + + continue; + } + } + + return Status::Success; +} + +Status +SequenceStates::OutputState( + const std::string& name, const inference::DataType datatype, + const int64_t* shape, const uint64_t dim_count, + SequenceState** output_state) +{ + const auto& output_state_itr = output_states_.find(name); + + // If the state name is not valid return an error. + if (output_state_itr == output_states_.end()) { + return Status( + Status::Code::INVALID_ARG, + "state '" + name + "' is not a valid state name."); + } + + if (output_states_[name] == nullptr) { + output_states_[name] = std::unique_ptr( + new SequenceState(name, datatype, shape, dim_count)); + } else { + // A new SequenceState is created here in case the shape for the new output + // state is different from the shape of the originally stored state. + std::unique_ptr output_state( + new SequenceState(name, datatype, shape, dim_count)); + + // Transfer the previously allocated buffer to the new output_state. + output_state->SetData(output_states_[name]->Data()); + output_states_[name] = std::move(output_state); + } + + auto& output_state_r = output_states_[name]; + size_t iter_advance = + std::distance(output_states_.begin(), output_states_.find(name)); + + // Find the input state corresponding to this output state. 
+ auto input_states_itr = input_states_.begin(); + std::advance(input_states_itr, iter_advance); + auto& input_state_r = input_states_[input_states_itr->first]; + + if (output_state != nullptr) { + *output_state = output_states_[name].get(); + } + + output_state_r->SetStateUpdateCallback([&output_state_r, &input_state_r]() { + // Swap the internal memory if the size of the input and output state is + // equal + + if (output_state_r->Data()->TotalByteSize() == + input_state_r->Data()->TotalByteSize()) { + std::shared_ptr temp_memory = input_state_r->Data(); + RETURN_IF_ERROR(input_state_r->RemoveAllData()); + RETURN_IF_ERROR(input_state_r->SetData(output_state_r->Data())); + RETURN_IF_ERROR(output_state_r->RemoveAllData()); + RETURN_IF_ERROR(output_state_r->SetData(temp_memory)); + } else { + // If the size of output state is different from the input state, allocate + // a new memory for the input state with the same size as output state. + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + + const std::shared_ptr& input_memory = + reinterpret_cast&>( + input_state_r->Data()); + + input_memory->MutableBuffer(&memory_type, &memory_type_id); + std::shared_ptr memory = + std::make_shared( + output_state_r->Data()->TotalByteSize(), memory_type, + memory_type_id); + RETURN_IF_ERROR(input_state_r->RemoveAllData()); + RETURN_IF_ERROR(input_state_r->SetData(output_state_r->Data())); + RETURN_IF_ERROR(output_state_r->RemoveAllData()); + RETURN_IF_ERROR(output_state_r->SetData(memory)); + } + + // Update the shape and data type of the output state if it doesn't match + // the input state. + if (input_state_r->Shape() != output_state_r->Shape()) { + *input_state_r->MutableShape() = output_state_r->Shape(); + } + + if (input_state_r->DType() != output_state_r->DType()) { + *input_state_r->MutableDType() = output_state_r->DType(); + } + + return Status::Success; + }); + + return Status::Success; +} + +Status +SequenceStates::OutputState( + const std::string& name, const inference::DataType datatype, + const std::vector& shape, SequenceState** output_state) +{ + return OutputState(name, datatype, shape.data(), shape.size(), output_state); +} + +std::shared_ptr +SequenceStates::CopyAsNull(const std::shared_ptr& from) +{ + std::shared_ptr lsequence_states; + if (from != nullptr) { + lsequence_states.reset(new SequenceStates); + for (auto& from_input_state : from->InputStates()) { + auto& from_input_state_tensor = from_input_state.second; + const auto& input_pair = lsequence_states->input_states_.emplace( + std::piecewise_construct, + std::forward_as_tuple(from_input_state_tensor->Name()), + std::forward_as_tuple(new SequenceState( + from_input_state_tensor->Name(), from_input_state_tensor->DType(), + from_input_state_tensor->Shape()))); + + auto& input_tensor = input_pair.first->second; + std::shared_ptr data; + if (from_input_state_tensor->DType() == + inference::DataType::TYPE_STRING) { + // Use all-zero input states for null requests. 
+ auto element_count = + triton::common::GetElementCount(from_input_state_tensor->Shape()); + auto state_size = 4 * element_count; + data = std::make_shared( + state_size, TRITONSERVER_MEMORY_CPU, 0); + } else { + data = std::make_shared( + from_input_state_tensor->Data()->TotalByteSize(), + TRITONSERVER_MEMORY_CPU, 0); + } + + input_tensor->SetData(data); + if (input_tensor->DType() == inference::DataType::TYPE_STRING) { + input_tensor->SetStringDataToZero(); + } + } + + for (auto& from_output_state : from->OutputStates()) { + lsequence_states->output_states_.emplace( + std::piecewise_construct, + std::forward_as_tuple(from_output_state.first), + std::forward_as_tuple()); + } + } + return lsequence_states; +} +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/sequence_state.h b/3rdparty/core-r22.12/src/sequence_state.h new file mode 100644 index 0000000000000000000000000000000000000000..a2d0b14244799b6e3b1c4aef931163315e7eb880 --- /dev/null +++ b/3rdparty/core-r22.12/src/sequence_state.h @@ -0,0 +1,168 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include +#include +#include "memory.h" +#include "status.h" +#include "triton/common/model_config.h" + +#pragma once + +namespace triton { namespace core { + +// +// Sequence state tensors. +// +class SequenceState { + public: + SequenceState(); + SequenceState( + const std::string& name, const inference::DataType datatype, + const std::vector& shape); + SequenceState( + const std::string& name, const inference::DataType datatype, + const int64_t* shape, const uint64_t dim_count); + + // The name of the state tensor. + const std::string& Name() const { return name_; } + + // Data type of the state tensor. + inference::DataType DType() const { return datatype_; } + + // Mutable data type of the state tensor. + inference::DataType* MutableDType() { return &datatype_; } + + // The shape of the state tensor after normalization. 
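  // Editorial example (illustrative): a state configured with dims [-1, 4]
  // on a model with a non-zero max_batch_size is first created with shape
  // [1, 1, 4]; variable dimensions default to 1 for the initial request and
  // may later be adjusted through MutableShape().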
+ const std::vector& Shape() const { return shape_; } + std::vector* MutableShape() { return &shape_; } + + // The data for this shape. + std::shared_ptr& Data() { return data_; } + + // Set the data for this shape. Error if state already has some + // data. + Status SetData(const std::shared_ptr& data); + + // Sets state tensors that have type string to zero + Status SetStringDataToZero(); + + // Remove all existing data for the state. + Status RemoveAllData(); + + // Set the state update callback. + void SetStateUpdateCallback(std::function&& state_update_cb) + { + state_update_cb_ = std::move(state_update_cb); + } + + // Call the state update callback. This function will be called when + // TRITONBACKEND_StateUpdate is called. + Status Update() { return state_update_cb_(); } + + private: + DISALLOW_COPY_AND_ASSIGN(SequenceState); + std::string name_; + inference::DataType datatype_; + std::vector shape_; + std::vector batch_dim_; + std::shared_ptr data_; + std::function state_update_cb_ = []() { + // By default calling the TRITONBACKEND_StateUpdate will return an error. + return Status( + Status::Code::INVALID_ARG, + "TRITONBACKEND_StateUpdate called when sequence batching is disabled " + "or the 'states' section of the model configuration is empty."); + }; +}; + +class SequenceStates { + public: + struct InitialStateData { + InitialStateData(const std::string& state_init_name) + : state_init_name_(state_init_name) + { + } + + std::string state_init_name_; + std::shared_ptr data_; + }; + + // Initialize the state tensors according to the state model configuration. + // Will use a default value of 1 for the variable dimensions in the state + // tensor configuration. + Status Initialize( + const std::unordered_map< + std::string, const inference::ModelSequenceBatching_State&>& + state_output_config_map, + const size_t max_batch_size, + const std::unordered_map& initial_state); + + // Get a buffer holding the output state. + Status OutputState( + const std::string& name, const inference::DataType datatype, + const int64_t* shape, const uint64_t dim_count, + SequenceState** output_state); + Status OutputState( + const std::string& name, const inference::DataType datatype, + const std::vector& shape, SequenceState** output_state); + + // Create a copy of the 'from' sequence states for NULL requests. + static std::shared_ptr CopyAsNull( + const std::shared_ptr& from); + + const std::map>& InputStates() + { + return input_states_; + } + + std::map>& OutputStates() + { + return output_states_; + } + + void SetNullSequenceStates(std::shared_ptr sequence_states) + { + null_sequence_states_ = sequence_states; + is_null_request_ = true; + } + + const std::shared_ptr& NullSequenceStates() + { + return null_sequence_states_; + } + + bool IsNullRequest() { return is_null_request_; } + + private: + std::map> input_states_; + std::map> output_states_; + std::shared_ptr null_sequence_states_; + bool is_null_request_ = false; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/server.cc b/3rdparty/core-r22.12/src/server.cc new file mode 100644 index 0000000000000000000000000000000000000000..e313a16e2a2488b0478403a4b97849eb11bd3fed --- /dev/null +++ b/3rdparty/core-r22.12/src/server.cc @@ -0,0 +1,653 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "server.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "backend_manager.h" +#include "constants.h" +#include "cuda_utils.h" +#include "model.h" +#include "model_config.pb.h" +#include "model_config_utils.h" +#include "model_repository_manager.h" +#include "pinned_memory_manager.h" +#include "repo_agent.h" +#include "triton/common/async_work_queue.h" +#include "triton/common/logging.h" +#include "triton/common/model_config.h" +#include "triton/common/table_printer.h" + +#ifdef TRITON_ENABLE_GPU +#include "cuda_memory_manager.h" +#endif // TRITON_ENABLE_GPU + +namespace triton { namespace core { + +namespace { + +// Scoped increment / decrement of atomic +class ScopedAtomicIncrement { + public: + explicit ScopedAtomicIncrement(std::atomic& counter) + : counter_(counter) + { + counter_++; + } + + ~ScopedAtomicIncrement() { counter_--; } + + private: + std::atomic& counter_; +}; + +} // namespace + +// +// InferenceServer +// +InferenceServer::InferenceServer() + : version_(TRITON_VERSION), ready_state_(ServerReadyState::SERVER_INVALID) +{ + id_ = "triton"; + extensions_.push_back("classification"); + extensions_.push_back("sequence"); + extensions_.push_back("model_repository"); + extensions_.push_back("model_repository(unload_dependents)"); + extensions_.push_back("schedule_policy"); + extensions_.push_back("model_configuration"); + extensions_.push_back("system_shared_memory"); + extensions_.push_back("cuda_shared_memory"); + extensions_.push_back("binary_tensor_data"); +#ifdef TRITON_ENABLE_STATS + extensions_.push_back("statistics"); +#endif // TRITON_ENABLE_STATS +#ifdef TRITON_ENABLE_TRACING + extensions_.push_back("trace"); +#endif // TRITON_ENABLE_TRACING +#ifdef TRITON_ENABLE_LOGGING + extensions_.push_back("logging"); +#endif // TRITON_ENABLE_LOGGING + strict_model_config_ = true; + strict_readiness_ = true; + exit_timeout_secs_ = 30; + pinned_memory_pool_size_ = 1 << 28; + buffer_manager_thread_count_ = 0; + model_load_thread_count_ = + std::max(2u, 2 
* std::thread::hardware_concurrency()); + +#ifdef TRITON_ENABLE_GPU + min_supported_compute_capability_ = TRITON_MIN_COMPUTE_CAPABILITY; +#else + min_supported_compute_capability_ = 0.0; +#endif // TRITON_ENABLE_GPU + + inflight_request_counter_ = 0; +} + +Status +InferenceServer::Init() +{ + Status status; + + ready_state_ = ServerReadyState::SERVER_INITIALIZING; + + if (model_repository_paths_.empty()) { + ready_state_ = ServerReadyState::SERVER_FAILED_TO_INITIALIZE; + return Status( + Status::Code::INVALID_ARG, "--model-repository must be specified"); + } + + if (repoagent_dir_.empty()) { + ready_state_ = ServerReadyState::SERVER_FAILED_TO_INITIALIZE; + return Status( + Status::Code::INVALID_ARG, "--repoagent-directory can not be empty"); + } + + status = TritonRepoAgentManager::SetGlobalSearchPath(repoagent_dir_); + if (!status.IsOk()) { + ready_state_ = ServerReadyState::SERVER_FAILED_TO_INITIALIZE; + return status; + } + + status = TritonBackendManager::Create(&backend_manager_); + if (!status.IsOk()) { + ready_state_ = ServerReadyState::SERVER_FAILED_TO_INITIALIZE; + return status; + } + + if (buffer_manager_thread_count_ > 0) { + status = CommonErrorToStatus(triton::common::AsyncWorkQueue::Initialize( + buffer_manager_thread_count_)); + if (!status.IsOk()) { + ready_state_ = ServerReadyState::SERVER_FAILED_TO_INITIALIZE; + return status; + } + } + + std::unique_ptr local_rate_limiter; + bool ignore_resources_and_priority = + (rate_limit_mode_ == RateLimitMode::RL_OFF); + + status = RateLimiter::Create( + ignore_resources_and_priority, rate_limit_resource_map_, + &local_rate_limiter); + rate_limiter_ = std::move(local_rate_limiter); + if (!status.IsOk()) { + ready_state_ = ServerReadyState::SERVER_FAILED_TO_INITIALIZE; + return status; + } + + PinnedMemoryManager::Options options(pinned_memory_pool_size_); + status = PinnedMemoryManager::Create(options); + if (!status.IsOk()) { + ready_state_ = ServerReadyState::SERVER_FAILED_TO_INITIALIZE; + return status; + } + + if (response_cache_byte_size_ > 0) { + std::unique_ptr local_response_cache; + status = RequestResponseCache::Create( + response_cache_byte_size_, &local_response_cache); + if (!status.IsOk()) { + ready_state_ = ServerReadyState::SERVER_FAILED_TO_INITIALIZE; + return status; + } + + response_cache_ = std::move(local_response_cache); + } + + +#ifdef TRITON_ENABLE_GPU + // Set the default CUDA memory pool size for GPUs where it is not + // set explicitly. + std::set supported_gpus; + if (GetSupportedGPUs(&supported_gpus, min_supported_compute_capability_) + .IsOk()) { + for (const auto gpu : supported_gpus) { + if (cuda_memory_pool_size_.find(gpu) == cuda_memory_pool_size_.end()) { + cuda_memory_pool_size_[gpu] = 1 << 26; + } + } + } + + CudaMemoryManager::Options cuda_options( + min_supported_compute_capability_, cuda_memory_pool_size_); + status = CudaMemoryManager::Create(cuda_options); + // If CUDA memory manager can't be created, just log error as the + // server can still function properly + if (!status.IsOk()) { + LOG_ERROR << status.Message(); + } +#endif // TRITON_ENABLE_GPU + + status = EnablePeerAccess(min_supported_compute_capability_); + if (!status.IsOk()) { + // failed to enable peer access is not critical, just inefficient. + LOG_WARNING << status.Message(); + } + + // Create the model manager for the repository. Unless model control + // is disabled, all models are eagerly loaded when the manager is created. 
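  // Editorial note (illustrative): polling_enabled below corresponds to
  // --model-control-mode=poll and model_control_enabled to
  // --model-control-mode=explicit; with the default mode "none" both are
  // false and every model found in the repository is loaded at startup.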
+ bool polling_enabled = (model_control_mode_ == ModelControlMode::MODE_POLL); + bool model_control_enabled = + (model_control_mode_ == ModelControlMode::MODE_EXPLICIT); + const ModelLifeCycleOptions life_cycle_options( + min_supported_compute_capability_, backend_cmdline_config_map_, + host_policy_map_, model_load_thread_count_); + status = ModelRepositoryManager::Create( + this, version_, model_repository_paths_, startup_models_, + strict_model_config_, polling_enabled, model_control_enabled, + life_cycle_options, &model_repository_manager_); + if (!status.IsOk()) { + if (model_repository_manager_ == nullptr) { + ready_state_ = ServerReadyState::SERVER_FAILED_TO_INITIALIZE; + } else { + // If error is returned while the manager is set, we assume the + // failure is due to a model not loading correctly so we just + // continue if not exiting on error. + ready_state_ = ServerReadyState::SERVER_READY; + PrintBackendAndModelSummary(); + } + } else { + ready_state_ = ServerReadyState::SERVER_READY; + PrintBackendAndModelSummary(); + } + + return status; +} + +Status +InferenceServer::Stop(const bool force) +{ + if (!force && (ready_state_ != ServerReadyState::SERVER_READY)) { + return Status::Success; + } + + ready_state_ = ServerReadyState::SERVER_EXITING; + + if (model_repository_manager_ == nullptr) { + LOG_INFO << "No server context available. Exiting immediately."; + return Status::Success; + } else { + LOG_INFO << "Waiting for in-flight requests to complete."; + } + + Status status = model_repository_manager_->StopAllModels(); + if (!status.IsOk()) { + LOG_ERROR << status.Message(); + } + + // Wait for all in-flight non-inference requests to complete and all + // loaded models to unload, or for the exit timeout to expire. + uint32_t exit_timeout_iters = exit_timeout_secs_; + bool unloading_model = false; + while (true) { + if (!unloading_model) { + // Check if all in-flight inference requests / sequences are completed + const auto& inflight_status = model_repository_manager_->InflightStatus(); + LOG_INFO << "Timeout " << exit_timeout_iters << ": Found " + << inflight_status.size() + << " model versions that have in-flight inferences"; + for (const auto& inflight : inflight_status) { + LOG_INFO << "Model '" << std::get<0>(inflight) << "' " + << "(version " << std::get<1>(inflight) << ") has " + << std::get<2>(inflight) << " in-flight inferences"; + } + + if (inflight_status.size() == 0) { + unloading_model = true; + status = model_repository_manager_->UnloadAllModels(); + if (!status.IsOk()) { + LOG_ERROR << status.Message(); + } else { + LOG_INFO << "All models are stopped, unloading models"; + continue; + } + } + } else { + const auto& live_models = model_repository_manager_->LiveModelStates(); + + LOG_INFO << "Timeout " << exit_timeout_iters << ": Found " + << live_models.size() << " live models and " + << inflight_request_counter_ + << " in-flight non-inference requests"; + if (LOG_VERBOSE_IS_ON(1)) { + for (const auto& m : live_models) { + for (const auto& v : m.second) { + LOG_VERBOSE(1) << m.first << " v" << v.first << ": " + << ModelReadyStateString(v.second.first); + } + } + } + + if ((live_models.size() == 0) && (inflight_request_counter_ == 0)) { + return Status::Success; + } + } + if (exit_timeout_iters <= 0) { + break; + } + + exit_timeout_iters--; + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + + return Status( + Status::Code::INTERNAL, "Exit timeout expired. 
Exiting immediately."); +} + +Status +InferenceServer::PollModelRepository() +{ + LOG_VERBOSE(1) << "Polling model repository"; + + // Look for changes and update the loaded model configurations + // appropriately. + if (ready_state_ == ServerReadyState::SERVER_READY) { + ScopedAtomicIncrement inflight(inflight_request_counter_); + RETURN_IF_ERROR(model_repository_manager_->PollAndUpdate()); + } + + return Status::Success; +} + +Status +InferenceServer::IsLive(bool* live) +{ + *live = false; + + if (ready_state_ == ServerReadyState::SERVER_EXITING) { + return Status(Status::Code::UNAVAILABLE, "Server exiting"); + } + + ScopedAtomicIncrement inflight(inflight_request_counter_); + + // Server is considered live if it can respond to this health + // request and it was able to initialize. + *live = + ((ready_state_ != ServerReadyState::SERVER_INVALID) && + (ready_state_ != ServerReadyState::SERVER_INITIALIZING) && + (ready_state_ != ServerReadyState::SERVER_FAILED_TO_INITIALIZE)); + return Status::Success; +} + +Status +InferenceServer::IsReady(bool* ready) +{ + *ready = false; + + if (ready_state_ == ServerReadyState::SERVER_EXITING) { + return Status(Status::Code::UNAVAILABLE, "Server exiting"); + } + + ScopedAtomicIncrement inflight(inflight_request_counter_); + + // Server is considered ready if it is in the ready state. + // Additionally can report ready only when all models are ready. + *ready = (ready_state_ == ServerReadyState::SERVER_READY); + if (*ready && strict_readiness_) { + // Strict readiness... get the model status and make sure all + // models are ready. + const auto model_versions = model_repository_manager_->ModelStates(); + + for (const auto& mv : model_versions) { + // If a model status is present but no version status, + // the model is not ready as there is no proper version to be served + if (mv.second.size() == 0) { + *ready = false; + goto strict_done; + } + for (const auto& vs : mv.second) { + // Okay if model is not ready due to unload + if ((vs.second.first != ModelReadyState::READY) && + (vs.second.second != "unloaded")) { + *ready = false; + goto strict_done; + } + } + } + strict_done:; + } + + return Status::Success; +} + +Status +InferenceServer::ModelIsReady( + const std::string& model_name, const int64_t model_version, bool* ready) +{ + *ready = false; + + if (ready_state_ != ServerReadyState::SERVER_READY) { + return Status(Status::Code::UNAVAILABLE, "Server not ready"); + } + + ScopedAtomicIncrement inflight(inflight_request_counter_); + + std::shared_ptr model; + if (GetModel(model_name, model_version, &model).IsOk()) { + ModelReadyState state; + if (model_repository_manager_ + ->ModelState(model_name, model->Version(), &state) + .IsOk()) { + *ready = (state == ModelReadyState::READY); + } + } + + return Status::Success; +} + +Status +InferenceServer::ModelReadyVersions( + const std::string& model_name, std::vector* versions) +{ + if (ready_state_ != ServerReadyState::SERVER_READY) { + return Status(Status::Code::UNAVAILABLE, "Server not ready"); + } + + ScopedAtomicIncrement inflight(inflight_request_counter_); + + const auto version_states = + model_repository_manager_->VersionStates(model_name); + for (const auto& pr : version_states) { + if (pr.second.first == ModelReadyState::READY) { + versions->push_back(pr.first); + } + } + + return Status::Success; +} + +Status +InferenceServer::ModelReadyVersions( + std::map>* ready_model_versions) +{ + if (ready_state_ != ServerReadyState::SERVER_READY) { + return Status(Status::Code::UNAVAILABLE, "Server not 
ready"); + } + + ScopedAtomicIncrement inflight(inflight_request_counter_); + + const auto model_versions = + model_repository_manager_->LiveModelStates(true /* strict_readiness */); + + ready_model_versions->clear(); + std::vector versions; + for (const auto& mv_pair : model_versions) { + for (const auto& vs_pair : mv_pair.second) { + versions.emplace_back(vs_pair.first); + } + ready_model_versions->emplace(mv_pair.first, std::move(versions)); + } + + return Status::Success; +} + +Status +InferenceServer::RepositoryIndex( + const bool ready_only, + std::vector* index) +{ + if (ready_state_ != ServerReadyState::SERVER_READY) { + return Status(Status::Code::UNAVAILABLE, "Server not ready"); + } + + ScopedAtomicIncrement inflight(inflight_request_counter_); + + return model_repository_manager_->RepositoryIndex(ready_only, index); +} + +Status +InferenceServer::InferAsync(std::unique_ptr& request) +{ + // Allow inference request while server exiting to provide graceful + // completion of inference sequence that spans multiple requests. + if ((ready_state_ != ServerReadyState::SERVER_READY) && + (ready_state_ != ServerReadyState::SERVER_EXITING)) { + return Status(Status::Code::UNAVAILABLE, "Server not ready"); + } + +#ifdef TRITON_ENABLE_STATS + request->CaptureRequestStartNs(); + INFER_TRACE_ACTIVITY( + request->Trace(), TRITONSERVER_TRACE_REQUEST_START, + request->RequestStartNs()); +#endif // TRITON_ENABLE_STATS + + return InferenceRequest::Run(request); +} + +Status +InferenceServer::LoadModel( + const std::unordered_map< + std::string, std::vector>& models) +{ + if (ready_state_ != ServerReadyState::SERVER_READY) { + return Status(Status::Code::UNAVAILABLE, "Server not ready"); + } + + ScopedAtomicIncrement inflight(inflight_request_counter_); + + auto action_type = ActionType::LOAD; + return model_repository_manager_->LoadUnloadModel( + models, action_type, false /* unload_dependents */); +} + +Status +InferenceServer::UnloadModel( + const std::string& model_name, const bool unload_dependents) +{ + if (ready_state_ != ServerReadyState::SERVER_READY) { + return Status(Status::Code::UNAVAILABLE, "Server not ready"); + } + + ScopedAtomicIncrement inflight(inflight_request_counter_); + + auto action_type = ActionType::UNLOAD; + return model_repository_manager_->LoadUnloadModel( + {{model_name, {}}}, action_type, unload_dependents); +} + +Status +InferenceServer::PrintBackendAndModelSummary() +{ + // Repository Agents Summary + std::vector repoagent_headers; + repoagent_headers.emplace_back("Repository Agent"); + repoagent_headers.emplace_back("Path"); + + triton::common::TablePrinter repoagents_table(repoagent_headers); + + std::unique_ptr> repoagent_state; + RETURN_IF_ERROR(TritonRepoAgentManager::AgentState(&repoagent_state)); + + for (const auto& repoagent_pair : *repoagent_state) { + std::vector repoagent_record; + repoagent_record.emplace_back(repoagent_pair.first); + repoagent_record.emplace_back(repoagent_pair.second); + repoagents_table.InsertRow(repoagent_record); + } + std::string repoagents_table_string = repoagents_table.PrintTable(); + LOG_INFO << repoagents_table_string; + + // Backends Summary + std::vector backend_headers; + backend_headers.emplace_back("Backend"); + backend_headers.emplace_back("Path"); + backend_headers.emplace_back("Config"); + + triton::common::TablePrinter backends_table(backend_headers); + + std::unique_ptr>> + backend_state; + RETURN_IF_ERROR(backend_manager_->BackendState(&backend_state)); + + for (const auto& backend_pair : *backend_state) { + 
std::vector backend_record; + + // Backend Name + backend_record.emplace_back(backend_pair.first); + + // Backend config and lib path + for (const auto& backend_field : backend_pair.second) { + backend_record.emplace_back(backend_field); + } + backends_table.InsertRow(backend_record); + } + std::string backends_table_string = backends_table.PrintTable(); + LOG_INFO << backends_table_string; + + // Models Summary + auto model_states = model_repository_manager_->ModelStates(); + + std::vector model_headers; + model_headers.emplace_back("Model"); + model_headers.emplace_back("Version"); + model_headers.emplace_back("Status"); + + triton::common::TablePrinter models_table(model_headers); + + for (const auto& model_state : model_states) { + auto model_version_map = model_state.second; + std::string model_name = model_state.first; + + // If model_version_map size is zero, no version is found for this model + if (model_version_map.size() == 0) { + std::vector model_record; + model_record.emplace_back(model_name); + model_record.emplace_back("-"); + model_record.emplace_back("Not loaded: No model version was found"); + models_table.InsertRow(model_record); + } else { + for (const auto& model_map : model_version_map) { + std::vector model_record; + std::string model_version = std::to_string(model_map.first); + auto model_status_pair = model_map.second; + std::string model_status = + ModelReadyStateString(model_status_pair.first); + + if (model_status_pair.second != "") { + model_status += ": " + model_status_pair.second; + } + + model_record.emplace_back(model_name); + model_record.emplace_back(model_version); + model_record.emplace_back(model_status); + models_table.InsertRow(model_record); + } + } + } + std::string models_table_string = models_table.PrintTable(); + LOG_INFO << models_table_string; + + return Status::Success; +} + +Status +InferenceServer::RegisterModelRepository( + const std::string& repository, + const std::unordered_map& model_mapping) +{ + return model_repository_manager_->RegisterModelRepository( + repository, model_mapping); +} + +Status +InferenceServer::UnregisterModelRepository(const std::string& repository) +{ + return model_repository_manager_->UnregisterModelRepository(repository); +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/server.h b/3rdparty/core-r22.12/src/server.h new file mode 100644 index 0000000000000000000000000000000000000000..f1e3dab0af3f1cd9c767b304c6b475bb03d48300 --- /dev/null +++ b/3rdparty/core-r22.12/src/server.h @@ -0,0 +1,326 @@ +// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "backend_manager.h" +#include "infer_parameter.h" +#include "model_config.pb.h" +#include "model_repository_manager.h" +#include "rate_limiter.h" +#include "response_cache.h" +#include "status.h" +#include "triton/common/model_config.h" + +namespace triton { namespace core { + +class Model; +class InferenceRequest; + +enum class ModelControlMode { MODE_NONE, MODE_POLL, MODE_EXPLICIT }; + +enum class RateLimitMode { RL_EXEC_COUNT, RL_OFF }; + +// Readiness status for the inference server. +enum class ServerReadyState { + // The server is in an invalid state and will likely not response + // correctly to any requests. + SERVER_INVALID, + + // The server is initializing. + SERVER_INITIALIZING, + + // The server is ready and accepting requests. + SERVER_READY, + + // The server is exiting and will not respond to requests. + SERVER_EXITING, + + // The server did not initialize correctly. + SERVER_FAILED_TO_INITIALIZE +}; + +// Inference server information. +class InferenceServer { + public: + // Construct an inference server. + InferenceServer(); + + // Initialize the server. Return true on success, false otherwise. + Status Init(); + + // Stop the server. Return true if all models are unloaded, false + // if exit timeout occurs. If 'force' is true attempt to stop the + // server even if it is not in a ready state. + Status Stop(const bool force = false); + + // Check the model repository for changes and update server state + // based on those changes. + Status PollModelRepository(); + + // Server health + Status IsLive(bool* live); + Status IsReady(bool* ready); + + // Model health + Status ModelIsReady( + const std::string& model_name, const int64_t model_version, bool* ready); + + // Return the ready versions of specific model + Status ModelReadyVersions( + const std::string& model_name, std::vector* versions); + + // Return the ready versions of all models + Status ModelReadyVersions( + std::map>* model_versions); + + /// Get the index of all models in all repositories. + /// \param ready_only If true return only index of models that are ready. + /// \param index Returns the index. + /// \return error status. + Status RepositoryIndex( + const bool ready_only, + std::vector* index); + + // Inference. If Status::Success is returned then this function has + // taken ownership of the request object and so 'request' will be + // nullptr. If non-success is returned then the caller still retains + // ownership of 'request'. + Status InferAsync(std::unique_ptr& request); + + // Load the corresponding model. Reload the model if it has been loaded. 
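  // Editorial example (illustrative sketch; "resnet50" is a hypothetical
  // model name): an explicit load with no parameter overrides might be
  // issued as
  //
  //   RETURN_IF_ERROR(server->LoadModel({{"resnet50", {}}}));
  //
  // where each map entry pairs a model name with an optional list of
  // load-time parameter overrides.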
+ Status LoadModel( + const std::unordered_map< + std::string, std::vector>& models); + + // Unload the corresponding model. + Status UnloadModel( + const std::string& model_name, const bool unload_dependents); + + // Print backends and models summary + Status PrintBackendAndModelSummary(); + + // Register model repository path and associated mappings + Status RegisterModelRepository( + const std::string& repository, + const std::unordered_map& model_mapping); + + // Unregister model repository path. + Status UnregisterModelRepository(const std::string& repository); + + // Return the server version. + const std::string& Version() const { return version_; } + + // Return the server extensions. + const std::vector& Extensions() const { return extensions_; } + + // Get / set the ID of the server. + const std::string& Id() const { return id_; } + void SetId(const std::string& id) { id_ = id; } + + // Get / set the model repository path + const std::set& ModelRepositoryPaths() const + { + return model_repository_paths_; + } + + void SetModelRepositoryPaths(const std::set& p) + { + model_repository_paths_ = p; + } + + // Get / set model control mode. + ModelControlMode GetModelControlMode() const { return model_control_mode_; } + void SetModelControlMode(ModelControlMode m) { model_control_mode_ = m; } + + // Get / set the startup models + const std::set& StartupModels() const { return startup_models_; } + void SetStartupModels(const std::set& m) { startup_models_ = m; } + + // Get / set strict model configuration enable. + bool StrictModelConfigEnabled() const { return strict_model_config_; } + void SetStrictModelConfigEnabled(bool e) { strict_model_config_ = e; } + + // Get / set rate limiter mode. + RateLimitMode RateLimiterMode() const { return rate_limit_mode_; } + void SetRateLimiterMode(RateLimitMode m) { rate_limit_mode_ = m; } + + // Get / set rate limit resource counts + const RateLimiter::ResourceMap& RateLimiterResources() const + { + return rate_limit_resource_map_; + } + void SetRateLimiterResources(const RateLimiter::ResourceMap& rm) + { + rate_limit_resource_map_ = rm; + } + + // Get / set the pinned memory pool byte size. + int64_t PinnedMemoryPoolByteSize() const { return pinned_memory_pool_size_; } + void SetPinnedMemoryPoolByteSize(int64_t s) + { + pinned_memory_pool_size_ = std::max((int64_t)0, s); + } + + // Get / set the response cache byte size. + uint64_t ResponseCacheByteSize() const { return response_cache_byte_size_; } + void SetResponseCacheByteSize(uint64_t s) + { + response_cache_byte_size_ = s; + response_cache_enabled_ = (s > 0) ? true : false; + } + + bool ResponseCacheEnabled() const { return response_cache_enabled_; } + + // Get / set CUDA memory pool size + const std::map& CudaMemoryPoolByteSize() const + { + return cuda_memory_pool_size_; + } + + void SetCudaMemoryPoolByteSize(const std::map& s) + { + cuda_memory_pool_size_ = s; + } + + // Get / set the minimum support CUDA compute capability. + double MinSupportedComputeCapability() const + { + return min_supported_compute_capability_; + } + void SetMinSupportedComputeCapability(double c) + { + min_supported_compute_capability_ = c; + } + + // Get / set strict readiness enable. + bool StrictReadinessEnabled() const { return strict_readiness_; } + void SetStrictReadinessEnabled(bool e) { strict_readiness_ = e; } + + // Get / set the server exit timeout, in seconds. 
+ int32_t ExitTimeoutSeconds() const { return exit_timeout_secs_; } + void SetExitTimeoutSeconds(int32_t s) { exit_timeout_secs_ = std::max(0, s); } + + void SetBufferManagerThreadCount(unsigned int c) + { + buffer_manager_thread_count_ = c; + } + + void SetModelLoadThreadCount(unsigned int c) { model_load_thread_count_ = c; } + + // Set a backend command-line configuration + void SetBackendCmdlineConfig( + const triton::common::BackendCmdlineConfigMap& bc) + { + backend_cmdline_config_map_ = bc; + } + + void SetHostPolicyCmdlineConfig( + const triton::common::HostPolicyCmdlineConfigMap& hp) + { + host_policy_map_ = hp; + } + + void SetRepoAgentDir(const std::string& d) { repoagent_dir_ = d; } + + // Return the requested model object. + Status GetModel( + const std::string& model_name, const int64_t model_version, + std::shared_ptr* model) + { + // Allow model retrival while server exiting to provide graceful + // completion of inference sequence that spans multiple requests. + if ((ready_state_ != ServerReadyState::SERVER_READY) && + (ready_state_ != ServerReadyState::SERVER_EXITING)) { + return Status(Status::Code::UNAVAILABLE, "Server not ready"); + } + return model_repository_manager_->GetModel( + model_name, model_version, model); + } + + // Get the Backend Manager + const std::shared_ptr& BackendManager() + { + return backend_manager_; + } + + // Return the pointer to RateLimiter object. + std::shared_ptr GetRateLimiter() { return rate_limiter_; } + + // Return the pointer to response cache object. + std::shared_ptr GetResponseCache() + { + return response_cache_; + } + + private: + const std::string version_; + std::string id_; + std::vector extensions_; + + std::set model_repository_paths_; + std::set startup_models_; + ModelControlMode model_control_mode_; + bool strict_model_config_; + bool strict_readiness_; + uint32_t exit_timeout_secs_; + uint32_t buffer_manager_thread_count_; + uint32_t model_load_thread_count_; + uint64_t pinned_memory_pool_size_; + uint64_t response_cache_byte_size_; + bool response_cache_enabled_; + std::map cuda_memory_pool_size_; + double min_supported_compute_capability_; + triton::common::BackendCmdlineConfigMap backend_cmdline_config_map_; + triton::common::HostPolicyCmdlineConfigMap host_policy_map_; + std::string repoagent_dir_; + RateLimitMode rate_limit_mode_; + RateLimiter::ResourceMap rate_limit_resource_map_; + + + // Current state of the inference server. + ServerReadyState ready_state_; + + // Number of in-flight, non-inference requests. During shutdown we + // attempt to wait for all in-flight non-inference requests to + // complete before exiting (also wait for in-flight inference + // requests but that is determined by model shared_ptr). + std::atomic inflight_request_counter_; + + std::shared_ptr rate_limiter_; + std::unique_ptr model_repository_manager_; + std::shared_ptr backend_manager_; + std::shared_ptr response_cache_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/server_message.h b/3rdparty/core-r22.12/src/server_message.h new file mode 100644 index 0000000000000000000000000000000000000000..ae5d0668e69fdaadc573cf8340a80e047d6914c8 --- /dev/null +++ b/3rdparty/core-r22.12/src/server_message.h @@ -0,0 +1,90 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include "status.h" + +#define TRITONJSON_STATUSTYPE triton::core::Status +#define TRITONJSON_STATUSRETURN(M) \ + return triton::core::Status(triton::core::Status::Code::INTERNAL, (M)) +#define TRITONJSON_STATUSSUCCESS triton::core::Status::Success +#include "triton/common/triton_json.h" + +namespace triton { namespace core { + +// +// Implementation for TRITONSERVER_Message. +// +class TritonServerMessage { + public: + TritonServerMessage(const triton::common::TritonJson::Value& msg) + { + json_buffer_.Clear(); + msg.Write(&json_buffer_); + base_ = json_buffer_.Base(); + byte_size_ = json_buffer_.Size(); + from_json_ = true; + } + + TritonServerMessage(std::string&& msg) + { + str_buffer_ = std::move(msg); + base_ = str_buffer_.data(); + byte_size_ = str_buffer_.size(); + from_json_ = false; + } + + TritonServerMessage(const TritonServerMessage& rhs) + { + from_json_ = rhs.from_json_; + if (from_json_) { + json_buffer_ = rhs.json_buffer_; + base_ = json_buffer_.Base(); + byte_size_ = json_buffer_.Size(); + } else { + str_buffer_ = rhs.str_buffer_; + base_ = str_buffer_.data(); + byte_size_ = str_buffer_.size(); + } + } + + void Serialize(const char** base, size_t* byte_size) const + { + *base = base_; + *byte_size = byte_size_; + } + + private: + bool from_json_; + triton::common::TritonJson::WriteBuffer json_buffer_; + std::string str_buffer_; + + const char* base_; + size_t byte_size_; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/shared_library.cc b/3rdparty/core-r22.12/src/shared_library.cc new file mode 100644 index 0000000000000000000000000000000000000000..2bf00b15e74eaa66f0c9f472a67be355ef1d93db --- /dev/null +++ b/3rdparty/core-r22.12/src/shared_library.cc @@ -0,0 +1,231 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "shared_library.h" + +#include "filesystem.h" +#include "mutex" +#include "triton/common/logging.h" + +#ifdef _WIN32 +// suppress the min and max definitions in Windef.h. +#define NOMINMAX +#include +#else +#include +#endif + +namespace triton { namespace core { + +static std::mutex mu_; + +Status +SharedLibrary::Acquire(std::unique_ptr* slib) +{ + mu_.lock(); + slib->reset(new SharedLibrary()); + return Status::Success; +} + +SharedLibrary::~SharedLibrary() +{ + mu_.unlock(); +} + +Status +SharedLibrary::SetLibraryDirectory(const std::string& path) +{ +#ifdef _WIN32 + LOG_VERBOSE(1) << "SetLibraryDirectory: path = " << path; + if (!SetDllDirectory(path.c_str())) { + LPSTR err_buffer = nullptr; + size_t size = FormatMessageA( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&err_buffer, 0, NULL); + std::string errstr(err_buffer, size); + LocalFree(err_buffer); + + return Status( + Status::Code::NOT_FOUND, + "unable to set dll path " + path + ": " + errstr); + } +#endif + + return Status::Success; +} + +Status +SharedLibrary::ResetLibraryDirectory() +{ +#ifdef _WIN32 + LOG_VERBOSE(1) << "ResetLibraryDirectory"; + if (!SetDllDirectory(NULL)) { + LPSTR err_buffer = nullptr; + size_t size = FormatMessageA( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&err_buffer, 0, NULL); + std::string errstr(err_buffer, size); + LocalFree(err_buffer); + + return Status( + Status::Code::NOT_FOUND, "unable to reset dll path: " + errstr); + } +#endif + + return Status::Success; +} + +Status +SharedLibrary::OpenLibraryHandle(const std::string& path, void** handle) +{ + LOG_VERBOSE(1) << "OpenLibraryHandle: " << path; + +#ifdef _WIN32 + // Need to put shared library directory on the DLL path so that any + // dependencies of the shared library are found + const std::string library_dir 
= DirName(path); + RETURN_IF_ERROR(SetLibraryDirectory(library_dir)); + + // HMODULE is typedef of void* + // https://docs.microsoft.com/en-us/windows/win32/winprog/windows-data-types + LOG_VERBOSE(1) << "OpenLibraryHandle: path = " << path; + *handle = LoadLibrary(path.c_str()); + + // Remove the dll path added above... do this unconditionally before + // check for failure in dll load. + RETURN_IF_ERROR(ResetLibraryDirectory()); + + if (*handle == nullptr) { + LPSTR err_buffer = nullptr; + size_t size = FormatMessageA( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&err_buffer, 0, NULL); + std::string errstr(err_buffer, size); + LocalFree(err_buffer); + + return Status( + Status::Code::NOT_FOUND, "unable to load shared library: " + errstr); + } +#else + *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (*handle == nullptr) { + return Status( + Status::Code::NOT_FOUND, + "unable to load shared library: " + std::string(dlerror())); + } +#endif + + return Status::Success; +} + +Status +SharedLibrary::CloseLibraryHandle(void* handle) +{ + if (handle != nullptr) { +#ifdef _WIN32 + if (FreeLibrary((HMODULE)handle) == 0) { + LPSTR err_buffer = nullptr; + size_t size = FormatMessageA( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&err_buffer, 0, NULL); + std::string errstr(err_buffer, size); + LocalFree(err_buffer); + return Status( + Status::Code::INTERNAL, "unable to unload shared library: " + errstr); + } +#else + if (dlclose(handle) != 0) { + return Status( + Status::Code::INTERNAL, + "unable to unload shared library: " + std::string(dlerror())); + } +#endif + } + + return Status::Success; +} + +Status +SharedLibrary::GetEntrypoint( + void* handle, const std::string& name, const bool optional, void** befn) +{ + *befn = nullptr; + +#ifdef _WIN32 + void* fn = GetProcAddress((HMODULE)handle, name.c_str()); + if ((fn == nullptr) && !optional) { + LPSTR err_buffer = nullptr; + size_t size = FormatMessageA( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&err_buffer, 0, NULL); + std::string errstr(err_buffer, size); + LocalFree(err_buffer); + return Status( + Status::Code::NOT_FOUND, + "unable to find '" + name + + "' entrypoint in custom library: " + errstr); + } +#else + dlerror(); + void* fn = dlsym(handle, name.c_str()); + const char* dlsym_error = dlerror(); + if (dlsym_error != nullptr) { + if (optional) { + return Status::Success; + } + + std::string errstr(dlsym_error); // need copy as dlclose overwrites + return Status( + Status::Code::NOT_FOUND, "unable to find required entrypoint '" + name + + "' in shared library: " + errstr); + } + + if (fn == nullptr) { + if (optional) { + return Status::Success; + } + + return Status( + Status::Code::NOT_FOUND, + "unable to find required entrypoint '" + name + "' in shared library"); + } +#endif + + *befn = fn; + return Status::Success; +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/shared_library.h b/3rdparty/core-r22.12/src/shared_library.h new file mode 100644 index 0000000000000000000000000000000000000000..8ab12f3a6b07e6fde21dc154bca2cf5870928cec --- /dev/null +++ b/3rdparty/core-r22.12/src/shared_library.h @@ -0,0 +1,72 @@ +// Copyright 
(c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include "constants.h" +#include "status.h" + +namespace triton { namespace core { + +// SharedLibrary +// +// Utility functions for shared libraries. Because some operations +// require serialization, this object cannot be directly constructed +// and must instead be accessed using Acquire(). +class SharedLibrary { + public: + // Acquire a SharedLibrary object exclusively. Any other attempts to + // concurrently acquire a SharedLibrary object will block. + // object. Ownership is released by destroying the SharedLibrary + // object. + static Status Acquire(std::unique_ptr* slib); + + ~SharedLibrary(); + + // Configuration so that dependent libraries will be searched for in + // 'path' during OpenLibraryHandle. + Status SetLibraryDirectory(const std::string& path); + + // Reset any configuration done by SetLibraryDirectory. + Status ResetLibraryDirectory(); + + // Open shared library and return generic handle. + Status OpenLibraryHandle(const std::string& path, void** handle); + + // Close shared library. + Status CloseLibraryHandle(void* handle); + + // Get a generic pointer for an entrypoint into a shared library. + Status GetEntrypoint( + void* handle, const std::string& name, const bool optional, void** befn); + + private: + DISALLOW_COPY_AND_ASSIGN(SharedLibrary); + explicit SharedLibrary() = default; +}; + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/status.cc b/3rdparty/core-r22.12/src/status.cc new file mode 100644 index 0000000000000000000000000000000000000000..1640ee5ed08b34fa9af6bb29a27bb7b79258852a --- /dev/null +++ b/3rdparty/core-r22.12/src/status.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. 
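For context, the POSIX branch of SharedLibrary::OpenLibraryHandle and SharedLibrary::GetEntrypoint above reduces to dlopen with RTLD_NOW | RTLD_LOCAL, dlsym with explicit dlerror bookkeeping, and dlclose. The following standalone sketch shows the same pattern outside Triton; libm.so.6 and the cos symbol are arbitrary illustrative targets, not anything the core loads.

```cpp
// Standalone POSIX example of the dlopen/dlsym/dlclose pattern used by
// SharedLibrary::OpenLibraryHandle and SharedLibrary::GetEntrypoint.
// Build with: g++ demo.cc -ldl
#include <dlfcn.h>

#include <iostream>

int
main()
{
  // Open the library; RTLD_LOCAL keeps its symbols out of the global scope.
  void* handle = dlopen("libm.so.6", RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    std::cerr << "unable to load shared library: " << dlerror() << std::endl;
    return 1;
  }

  // Clear any stale error state, then look up the entrypoint. A null return
  // is only an error if dlerror() reports one, since nullptr can be a
  // legitimate symbol value.
  dlerror();
  void* fn = dlsym(handle, "cos");
  const char* dlsym_error = dlerror();
  if (dlsym_error != nullptr) {
    std::cerr << "unable to find entrypoint 'cos': " << dlsym_error
              << std::endl;
    dlclose(handle);
    return 1;
  }

  // Cast to the known signature and call it.
  auto cos_fn = reinterpret_cast<double (*)(double)>(fn);
  std::cout << "cos(0.0) = " << cos_fn(0.0) << std::endl;

  if (dlclose(handle) != 0) {
    std::cerr << "unable to unload shared library: " << dlerror() << std::endl;
    return 1;
  }
  return 0;
}
```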
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "status.h" + +namespace triton { namespace core { + +const Status Status::Success(Status::Code::SUCCESS); + +Status::Code +TritonCodeToStatusCode(TRITONSERVER_Error_Code code) +{ + switch (code) { + case TRITONSERVER_ERROR_UNKNOWN: + return Status::Code::UNKNOWN; + case TRITONSERVER_ERROR_INTERNAL: + return Status::Code::INTERNAL; + case TRITONSERVER_ERROR_NOT_FOUND: + return Status::Code::NOT_FOUND; + case TRITONSERVER_ERROR_INVALID_ARG: + return Status::Code::INVALID_ARG; + case TRITONSERVER_ERROR_UNAVAILABLE: + return Status::Code::UNAVAILABLE; + case TRITONSERVER_ERROR_UNSUPPORTED: + return Status::Code::UNSUPPORTED; + case TRITONSERVER_ERROR_ALREADY_EXISTS: + return Status::Code::ALREADY_EXISTS; + + default: + break; + } + + return Status::Code::UNKNOWN; +} + +TRITONSERVER_Error_Code +StatusCodeToTritonCode(Status::Code status_code) +{ + switch (status_code) { + case Status::Code::UNKNOWN: + return TRITONSERVER_ERROR_UNKNOWN; + case Status::Code::INTERNAL: + return TRITONSERVER_ERROR_INTERNAL; + case Status::Code::NOT_FOUND: + return TRITONSERVER_ERROR_NOT_FOUND; + case Status::Code::INVALID_ARG: + return TRITONSERVER_ERROR_INVALID_ARG; + case Status::Code::UNAVAILABLE: + return TRITONSERVER_ERROR_UNAVAILABLE; + case Status::Code::UNSUPPORTED: + return TRITONSERVER_ERROR_UNSUPPORTED; + case Status::Code::ALREADY_EXISTS: + return TRITONSERVER_ERROR_ALREADY_EXISTS; + + default: + break; + } + + return TRITONSERVER_ERROR_UNKNOWN; +} + +Status +CommonErrorToStatus(const triton::common::Error& error) +{ + return Status(error); +} + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/status.h b/3rdparty/core-r22.12/src/status.h new file mode 100644 index 0000000000000000000000000000000000000000..6efdf1522336c44a94d63c618e1be8dbf3091d4c --- /dev/null +++ b/3rdparty/core-r22.12/src/status.h @@ -0,0 +1,97 @@ +// Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include "triton/common/error.h" +#include "tritonserver_apis.h" + +namespace triton { namespace core { + +class Status : public triton::common::Error { + public: + // Construct a status from a code with no message. + explicit Status(Code code = Code::SUCCESS) : Error(code) {} + + // Construct a status from a code and message. + explicit Status(Code code, const std::string& msg) : Error(code, msg) {} + + // Construct a status from a code and message. + explicit Status(const Error& error) : Error(error) {} + + // Convenience "success" value. Can be used as Error::Success to + // indicate no error. + static const Status Success; + + // Return the code for this status. + Code StatusCode() const { return code_; } +}; + +// Return the Status::Code corresponding to a +// TRITONSERVER_Error_Code. +Status::Code TritonCodeToStatusCode(TRITONSERVER_Error_Code code); + +// Return the TRITONSERVER_Error_Code corresponding to a +// Status::Code. +TRITONSERVER_Error_Code StatusCodeToTritonCode(Status::Code status_code); + +// Converts the common Error to Status object +Status CommonErrorToStatus(const triton::common::Error& error); + +// If status is non-OK, return the Status. +#define RETURN_IF_ERROR(S) \ + do { \ + const Status& status__ = (S); \ + if (!status__.IsOk()) { \ + return status__; \ + } \ + } while (false) + +// If TRITONSERVER error is non-OK, return the corresponding status. +#define RETURN_IF_TRITONSERVER_ERROR(E) \ + do { \ + TRITONSERVER_Error* err__ = (E); \ + if (err__ != nullptr) { \ + Status status__ = Status( \ + TritonCodeToStatusCode(TRITONSERVER_ErrorCode(err__)), \ + TRITONSERVER_ErrorMessage(err__)); \ + TRITONSERVER_ErrorDelete(err__); \ + return status__; \ + } \ + } while (false) + +// If status is non-OK, return the corresponding TRITONSERVER_Error. 
+#define RETURN_TRITONSERVER_ERROR_IF_ERROR(S) \ + do { \ + const Status& status__ = (S); \ + if (!status__.IsOk()) { \ + return TRITONSERVER_ErrorNew( \ + StatusCodeToTritonCode(status__.StatusCode()), \ + status__.Message().c_str()); \ + } \ + } while (false) + +}} // namespace triton::core diff --git a/3rdparty/core-r22.12/src/test/async_work_queue_test.cc b/3rdparty/core-r22.12/src/test/async_work_queue_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..488e1daf1ace2c812becca55bc5e8f41938b9e40 --- /dev/null +++ b/3rdparty/core-r22.12/src/test/async_work_queue_test.cc @@ -0,0 +1,245 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
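The RETURN_IF_ERROR family of macros above lets each call site propagate a non-OK Status in a single line instead of repeating if/return boilerplate. Here is a self-contained sketch of that pattern; the Status class below is a deliberately simplified stand-in, not the triton::core::Status shown above (which derives from triton::common::Error), and the validation functions are hypothetical.

```cpp
// Minimal stand-in illustrating the RETURN_IF_ERROR early-return pattern.
#include <iostream>
#include <string>

class Status {
 public:
  enum class Code { SUCCESS, INVALID_ARG };
  Status() : code_(Code::SUCCESS) {}
  Status(Code code, const std::string& msg) : code_(code), msg_(msg) {}
  bool IsOk() const { return code_ == Code::SUCCESS; }
  const std::string& Message() const { return msg_; }

 private:
  Code code_;
  std::string msg_;
};

// Same shape as the macro in status.h: evaluate once, return early on error.
#define RETURN_IF_ERROR(S)        \
  do {                            \
    const Status& status__ = (S); \
    if (!status__.IsOk()) {       \
      return status__;            \
    }                             \
  } while (false)

Status
CheckBatchSize(int batch_size)
{
  if (batch_size <= 0) {
    return Status(Status::Code::INVALID_ARG, "batch size must be positive");
  }
  return Status();
}

Status
PrepareRequest(int batch_size)
{
  // Each step propagates its error to the caller without nested ifs.
  RETURN_IF_ERROR(CheckBatchSize(batch_size));
  // ... further validation / setup steps would follow the same pattern ...
  return Status();
}

int
main()
{
  Status status = PrepareRequest(0);
  if (!status.IsOk()) {
    std::cerr << "error: " << status.Message() << std::endl;
  }
  return 0;
}
```

RETURN_IF_TRITONSERVER_ERROR and RETURN_TRITONSERVER_ERROR_IF_ERROR apply the same early-return idea at the boundary between Status and the TRITONSERVER_Error* C API, converting in each direction.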
+#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include "triton/common/async_work_queue.h" + +namespace tc = triton::common; + +namespace { + +// Wrapper of AsyncWorkQueue class to expose Reset() for unit testing +class TestingAsyncWorkQueue : public tc::AsyncWorkQueue { + public: + static void Reset() { AsyncWorkQueue::Reset(); } +}; + +class AsyncWorkQueueTest : public ::testing::Test { + protected: + void TearDown() override { TestingAsyncWorkQueue::Reset(); } +}; + +TEST_F(AsyncWorkQueueTest, InitZeroWorker) +{ + auto error = tc::AsyncWorkQueue::Initialize(0); + EXPECT_FALSE(error.IsOk()) << "Expect error when initialized with 0 worker"; +} + +TEST_F(AsyncWorkQueueTest, InitOneWorker) +{ + auto error = tc::AsyncWorkQueue::Initialize(1); + EXPECT_TRUE(error.IsOk()) << error.Message(); +} + +TEST_F(AsyncWorkQueueTest, InitFourWorker) +{ + auto error = tc::AsyncWorkQueue::Initialize(4); + EXPECT_TRUE(error.IsOk()) << error.Message(); +} + +TEST_F(AsyncWorkQueueTest, InitTwice) +{ + auto error = tc::AsyncWorkQueue::Initialize(4); + EXPECT_TRUE(error.IsOk()) << error.Message(); + error = tc::AsyncWorkQueue::Initialize(2); + EXPECT_FALSE(error.IsOk()) << "Expect error from initializing twice"; +} + +TEST_F(AsyncWorkQueueTest, WorkerCountUninitialized) +{ + EXPECT_EQ(tc::AsyncWorkQueue::WorkerCount(), (size_t)0) + << "Expect 0 worker count for uninitialized queue"; +} + +TEST_F(AsyncWorkQueueTest, WorkerCountInitialized) +{ + auto error = tc::AsyncWorkQueue::Initialize(4); + EXPECT_TRUE(error.IsOk()) << error.Message(); + EXPECT_EQ(tc::AsyncWorkQueue::WorkerCount(), (size_t)4) + << "Expect 4 worker count for initialized queue"; +} + + +TEST_F(AsyncWorkQueueTest, RunTasksInParallel) +{ + auto AddTwoFn = [](const std::vector& lhs, const std::vector& rhs, + std::promise>* res) { + std::vector lres; + lres.reserve(lhs.size()); + for (size_t idx = 0; idx < lhs.size(); idx++) { + lres.push_back(lhs[idx] + rhs[idx]); + } + res->set_value(lres); + }; + + size_t task_count = 8; + std::vector> operands; + std::vector> expected_results; + { + // Use large element count to reduce the async work queue overhead + size_t element_count = 1 << 24; + auto RandHalfIntFn = std::bind( + std::uniform_int_distribution<>{std::numeric_limits::min() / 2, + std::numeric_limits::max() / 2}, + std::default_random_engine{}); + for (size_t tc = 0; tc < task_count + 1; tc++) { + expected_results.push_back(std::vector()); + operands.push_back(std::vector()); + operands.back().reserve(element_count); + for (size_t ec = 0; ec < element_count; ec++) { + operands.back().push_back(RandHalfIntFn()); + } + } + } + + // Get serialized time as baseline and store expected results + uint64_t serialized_duration = 0; + { + std::vector>> res(task_count); + + auto start_ts = + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now().time_since_epoch()) + .count(); + + for (size_t count = 0; count < task_count; count++) { + AddTwoFn(operands[count], operands[count + 1], &res[count]); + } + + auto end_ts = + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now().time_since_epoch()) + .count(); + + for (size_t count = 0; count < task_count; count++) { + expected_results[count] = std::move(res[count].get_future().get()); + } + serialized_duration = end_ts - start_ts; + } + + auto error = tc::AsyncWorkQueue::Initialize(4); + ASSERT_TRUE(error.IsOk()) << error.Message(); + + uint64_t parallelized_duration = 0; + { + std::vector>> ps(task_count); + 
std::vector>> fs; + for (auto& p : ps) { + fs.emplace_back(std::move(p.get_future())); + } + + auto start_ts = + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now().time_since_epoch()) + .count(); + + for (size_t count = 0; count < task_count; count++) { + tc::AsyncWorkQueue::AddTask([&AddTwoFn, &operands, &ps, count]() mutable { + AddTwoFn(operands[count], operands[count + 1], &ps[count]); + }); + } + for (size_t count = 0; count < task_count; count++) { + fs[count].wait(); + } + + auto end_ts = + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now().time_since_epoch()) + .count(); + + parallelized_duration = end_ts - start_ts; + // FIXME manual testing shows parallelized time is between 30% to 33.3% for + // 128 M total elements + EXPECT_LT(parallelized_duration, serialized_duration / 3) + << "Expected parallelized work was completed within 1/3 of serialized " + "time"; + for (size_t count = 0; count < task_count; count++) { + auto res = std::move(fs[count].get()); + EXPECT_EQ(res, expected_results[count]) + << "Mismatched parallelized result"; + } + } +} + +TEST_F(AsyncWorkQueueTest, RunTasksFIFO) +{ + auto CaptureTimestampFn = [](std::promise* res) { + res->set_value( + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now().time_since_epoch()) + .count()); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + }; + + size_t task_count = 8; + std::vector> ps(task_count); + + auto error = tc::AsyncWorkQueue::Initialize(2); + ASSERT_TRUE(error.IsOk()) << error.Message(); + + std::vector> barrier(2); + tc::AsyncWorkQueue::AddTask([&barrier]() mutable { + barrier[0].get_future().get(); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + }); + tc::AsyncWorkQueue::AddTask([&barrier]() mutable { + barrier[1].get_future().get(); + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + }); + for (size_t count = 0; count < task_count; count++) { + tc::AsyncWorkQueue::AddTask([count, &CaptureTimestampFn, &ps]() mutable { + CaptureTimestampFn(&ps[count]); + }); + } + + // Signal to start the work + barrier[0].set_value(); + barrier[1].set_value(); + + uint64_t prev_ts = 0; + for (size_t count = 0; count < task_count; count++) { + uint64_t curr_ts = ps[count].get_future().get(); + EXPECT_LT(prev_ts, curr_ts) + << "Expected async work is processed in FIFO order"; + } +} + +} // namespace + +int +main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/3rdparty/core-r22.12/src/test/memory_test.cc b/3rdparty/core-r22.12/src/test/memory_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6bd3b6293949425bcac6690899473488764645b9 --- /dev/null +++ b/3rdparty/core-r22.12/src/test/memory_test.cc @@ -0,0 +1,402 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. 
+// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#include "gtest/gtest.h" + +#include +#include "cuda_memory_manager.h" +#include "cuda_utils.h" +#include "memory.h" +#include "pinned_memory_manager.h" + +namespace tc = triton::core; + +namespace { + +#define CHECK_POINTER_ATTRIBUTES(ptr__, type__, device__) \ + do { \ + cudaPointerAttributes attr; \ + auto cuerr = cudaPointerGetAttributes(&attr, ptr__); \ + ASSERT_TRUE(cuerr == cudaSuccess) \ + << "Failed to get CUDA pointer attributes: " \ + << cudaGetErrorString(cuerr); \ + EXPECT_TRUE(attr.type == type__) \ + << "Expect pointer with type " << type__ << ", got: " << attr.type; \ + if (attr.type == cudaMemoryTypeDevice) { \ + EXPECT_TRUE(attr.device == device__) \ + << "Expect allocation on CUDA device " << device__ \ + << ", got: " << attr.device; \ + } \ + } while (false) + +// Wrapper of CudaMemoryManager class to expose Reset() for unit testing +class TestingCudaMemoryManager : public tc::CudaMemoryManager { + public: + static void Reset() { CudaMemoryManager::Reset(); } +}; + +class CudaMemoryManagerTest : public ::testing::Test { + protected: + void SetUp() override + { + // Default memory manager options + options_.min_supported_compute_capability_ = 6.0; + options_.memory_pool_byte_size_ = {{0, 1 << 10}}; + } + + void TearDown() override { TestingCudaMemoryManager::Reset(); } + + tc::CudaMemoryManager::Options options_; +}; + +TEST_F(CudaMemoryManagerTest, InitOOM) +{ + // Set to reserve too much memory + double cc = 6.0; + std::map s{{0, uint64_t(1) << 40 /* 1024 GB */}}; + const tc::CudaMemoryManager::Options options{cc, s}; + auto status = tc::CudaMemoryManager::Create(options); + EXPECT_FALSE(status.IsOk()) << "Expect creation error"; +} + +TEST_F(CudaMemoryManagerTest, InitSuccess) +{ + double cc = 6.0; + std::map s{{0, 1 << 10 /* 1024 bytes */}}; + const tc::CudaMemoryManager::Options options{cc, s}; + auto status = tc::CudaMemoryManager::Create(options); + EXPECT_TRUE(status.IsOk()) << status.Message(); +} + +TEST_F(CudaMemoryManagerTest, InitNoDeviceConfig) +{ + double cc = 6.0; + std::map s; + const tc::CudaMemoryManager::Options options{cc, s}; + auto status = tc::CudaMemoryManager::Create(options); + EXPECT_TRUE(status.IsOk()) << status.Message(); + + void* ptr = nullptr; + status = tc::CudaMemoryManager::Alloc(&ptr, 1, 0); + ASSERT_FALSE(status.IsOk()) << "Unexpected successful allocation"; +} + +TEST_F(CudaMemoryManagerTest, InitZeroByte) +{ + double cc = 6.0; + std::map s{{0, 0}}; + const tc::CudaMemoryManager::Options options{cc, s}; + auto status = tc::CudaMemoryManager::Create(options); + 
EXPECT_TRUE(status.IsOk()) << status.Message(); + + void* ptr = nullptr; + status = tc::CudaMemoryManager::Alloc(&ptr, 1, 0); + ASSERT_FALSE(status.IsOk()) << "Unexpected successful allocation"; +} + +TEST_F(CudaMemoryManagerTest, AllocSuccess) +{ + auto status = tc::CudaMemoryManager::Create(options_); + ASSERT_TRUE(status.IsOk()) << status.Message(); + + void* ptr = nullptr; + status = tc::CudaMemoryManager::Alloc(&ptr, 1024, 0); + ASSERT_TRUE(status.IsOk()) << status.Message(); + ASSERT_TRUE(ptr) << "Expect pointer to allocated buffer"; + // check if returned pointer is CUDA pointer + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeDevice, 0); +} + +TEST_F(CudaMemoryManagerTest, AllocFail) +{ + auto status = tc::CudaMemoryManager::Create(options_); + ASSERT_TRUE(status.IsOk()) << status.Message(); + + void* ptr = nullptr; + status = tc::CudaMemoryManager::Alloc(&ptr, 2048, 0); + ASSERT_FALSE(status.IsOk()) << "Unexpected successful allocation"; +} + +TEST_F(CudaMemoryManagerTest, MultipleAlloc) +{ + auto status = tc::CudaMemoryManager::Create(options_); + ASSERT_TRUE(status.IsOk()) << status.Message(); + + void* first_ptr = nullptr; + status = tc::CudaMemoryManager::Alloc(&first_ptr, 600, 0); + ASSERT_TRUE(status.IsOk()) << status.Message(); + ASSERT_TRUE(first_ptr) << "Expect pointer to allocated buffer"; + // check if returned pointer is CUDA pointer + CHECK_POINTER_ATTRIBUTES(first_ptr, cudaMemoryTypeDevice, 0); + + // 512 + 600 > 1024 + void* second_ptr = nullptr; + status = tc::CudaMemoryManager::Alloc(&second_ptr, 512, 0); + ASSERT_FALSE(status.IsOk()) << "Unexpected successful allocation"; + + // Free the first pointer and retry the second one + status = tc::CudaMemoryManager::Free(first_ptr, 0); + EXPECT_TRUE(status.IsOk()) << status.Message(); + status = tc::CudaMemoryManager::Alloc(&second_ptr, 512, 0); + ASSERT_TRUE(status.IsOk()) << status.Message(); + ASSERT_TRUE(second_ptr) << "Expect pointer to allocated buffer"; + // check if returned pointer is CUDA pointer + CHECK_POINTER_ATTRIBUTES(second_ptr, cudaMemoryTypeDevice, 0); +} + +TEST_F(CudaMemoryManagerTest, MultipleDevice) +{ + std::set supported_gpus; + auto status = tc::GetSupportedGPUs( + &supported_gpus, options_.min_supported_compute_capability_); + ASSERT_TRUE(status.IsOk()) << status.Message(); + ASSERT_GE(supported_gpus.size(), size_t(2)) + << "Test requires at least two supported CUDA devices"; + + { + double cc = 6.0; + std::map s; + // Only enough memory is only reserved in one of the devices + s[*supported_gpus.begin()] = 32; + s[*(++supported_gpus.begin())] = 1024; + const tc::CudaMemoryManager::Options options{cc, s}; + status = tc::CudaMemoryManager::Create(options); + ASSERT_TRUE(status.IsOk()) << status.Message(); + } + + void* ptr = nullptr; + // Allocation on small device + int small_device = *supported_gpus.begin(); + status = tc::CudaMemoryManager::Alloc(&ptr, 1024, small_device); + ASSERT_FALSE(status.IsOk()) << "Unexpected successful allocation"; + + // Allocation on large device + int large_device = *(++supported_gpus.begin()); + status = tc::CudaMemoryManager::Alloc(&ptr, 1024, large_device); + ASSERT_TRUE(status.IsOk()) << status.Message(); + ASSERT_TRUE(ptr) << "Expect pointer to allocated buffer"; + // check if returned pointer is CUDA pointer + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeDevice, large_device); + + // Free allocation ... 
+ status = tc::CudaMemoryManager::Free(ptr, small_device); + EXPECT_FALSE(status.IsOk()) << "Unexpected deallocation on wrong device"; + status = tc::CudaMemoryManager::Free(ptr, large_device); + EXPECT_TRUE(status.IsOk()) << status.Message(); +} + +class AllocatedMemoryTest : public ::testing::Test { + protected: + // Per-test-suite set-up. + static void SetUpTestSuite() + { + // Pinned memory manager + { + tc::PinnedMemoryManager::Options options{1024}; + auto status = tc::PinnedMemoryManager::Create(options); + ASSERT_TRUE(status.IsOk()) << status.Message(); + } + } + + // Set up CUDA memory manager per test for special fallback case + void SetUp() override + { + tc::CudaMemoryManager::Options options{6.0, {{0, 1 << 10}}}; + auto status = tc::CudaMemoryManager::Create(options); + ASSERT_TRUE(status.IsOk()) << status.Message(); + } + + void TearDown() override { TestingCudaMemoryManager::Reset(); } +}; + +TEST_F(AllocatedMemoryTest, AllocGPU) +{ + size_t expect_size = 512, actual_size; + TRITONSERVER_MemoryType expect_type = TRITONSERVER_MEMORY_GPU, actual_type; + int64_t expect_id = 0, actual_id; + tc::AllocatedMemory memory(expect_size, expect_type, expect_id); + + auto ptr = memory.BufferAt(0, &actual_size, &actual_type, &actual_id); + EXPECT_EQ(expect_size, actual_size) + << "Expect size: " << expect_size << ", got: " << actual_size; + EXPECT_EQ(expect_type, actual_type) + << "Expect type: " << expect_type << ", got: " << actual_type; + EXPECT_EQ(expect_id, actual_id) + << "Expect id: " << expect_id << ", got: " << actual_id; + + // Sanity check on the pointer property + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeDevice, expect_id); +} + +TEST_F(AllocatedMemoryTest, AllocPinned) +{ + size_t expect_size = 512, actual_size; + TRITONSERVER_MemoryType expect_type = TRITONSERVER_MEMORY_CPU_PINNED, + actual_type; + int64_t expect_id = 0, actual_id; + tc::AllocatedMemory memory(expect_size, expect_type, expect_id); + + auto ptr = memory.BufferAt(0, &actual_size, &actual_type, &actual_id); + EXPECT_EQ(expect_size, actual_size) + << "Expect size: " << expect_size << ", got: " << actual_size; + EXPECT_EQ(expect_type, actual_type) + << "Expect type: " << expect_type << ", got: " << actual_type; + EXPECT_EQ(expect_id, actual_id) + << "Expect id: " << expect_id << ", got: " << actual_id; + + // Sanity check on the pointer property + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeHost, expect_id); +} + +TEST_F(AllocatedMemoryTest, AllocFallback) +{ + // Each allocation uses half of the target reserved memory + size_t expect_size = 600, actual_size; + TRITONSERVER_MemoryType expect_type = TRITONSERVER_MEMORY_GPU, actual_type; + int64_t expect_id = 0, actual_id; + + // First allocation + tc::AllocatedMemory cuda_memory(expect_size, expect_type, expect_id); + + auto ptr = cuda_memory.BufferAt(0, &actual_size, &actual_type, &actual_id); + EXPECT_EQ(expect_size, actual_size) + << "Expect size: " << expect_size << ", got: " << actual_size; + EXPECT_EQ(expect_type, actual_type) + << "Expect type: " << expect_type << ", got: " << actual_type; + EXPECT_EQ(expect_id, actual_id) + << "Expect id: " << expect_id << ", got: " << actual_id; + + // Sanity check on the pointer property + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeDevice, expect_id); + + // Second allocation, should trigger fallback from CUDA -> pinned memory + tc::AllocatedMemory pinned_memory(expect_size, expect_type, expect_id); + + ptr = pinned_memory.BufferAt(0, &actual_size, &actual_type, &actual_id); + EXPECT_EQ(expect_size, actual_size) + << 
"Expect size: " << expect_size << ", got: " << actual_size; + EXPECT_EQ(TRITONSERVER_MEMORY_CPU_PINNED, actual_type) + << "Expect type: " << TRITONSERVER_MEMORY_CPU_PINNED + << ", got: " << actual_type; + + // Sanity check on the pointer property + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeHost, expect_id); + + // Third allocation, CUDA -> pinned -> non-pinned + tc::AllocatedMemory system_memory(expect_size, expect_type, expect_id); + + ptr = system_memory.BufferAt(0, &actual_size, &actual_type, &actual_id); + EXPECT_EQ(expect_size, actual_size) + << "Expect size: " << expect_size << ", got: " << actual_size; + EXPECT_EQ(TRITONSERVER_MEMORY_CPU, actual_type) + << "Expect type: " << TRITONSERVER_MEMORY_CPU_PINNED + << ", got: " << actual_type; + + // Sanity check on the pointer property + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeUnregistered, expect_id); +} + +TEST_F(AllocatedMemoryTest, AllocFallbackNoCuda) +{ + // Test fallback in the case where CUDA memory manager is not properly created + TestingCudaMemoryManager::Reset(); + + size_t expect_size = 600, actual_size; + TRITONSERVER_MemoryType expect_type = TRITONSERVER_MEMORY_GPU, actual_type; + int64_t expect_id = 0, actual_id; + + // CUDA memory allocation should trigger fallback to allocate pinned memory + tc::AllocatedMemory pinned_memory(expect_size, expect_type, expect_id); + + auto ptr = pinned_memory.BufferAt(0, &actual_size, &actual_type, &actual_id); + EXPECT_EQ(expect_size, actual_size) + << "Expect size: " << expect_size << ", got: " << actual_size; + EXPECT_EQ(TRITONSERVER_MEMORY_CPU_PINNED, actual_type) + << "Expect type: " << TRITONSERVER_MEMORY_CPU_PINNED + << ", got: " << actual_type; + + // Sanity check on the pointer property + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeHost, expect_id); +} + +TEST_F(AllocatedMemoryTest, Release) +{ + // Similar to above, but verify that the memory will be released once + // out of scope + // Each allocation uses half of the target reserved memory + size_t expect_size = 600, actual_size; + TRITONSERVER_MemoryType expect_type = TRITONSERVER_MEMORY_GPU, actual_type; + int64_t expect_id = 0, actual_id; + + { + // First allocation + tc::AllocatedMemory cuda_memory(expect_size, expect_type, expect_id); + + auto ptr = cuda_memory.BufferAt(0, &actual_size, &actual_type, &actual_id); + EXPECT_EQ(expect_size, actual_size) + << "Expect size: " << expect_size << ", got: " << actual_size; + EXPECT_EQ(expect_type, actual_type) + << "Expect type: " << expect_type << ", got: " << actual_type; + EXPECT_EQ(expect_id, actual_id) + << "Expect id: " << expect_id << ", got: " << actual_id; + + // Sanity check on the pointer property + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeDevice, expect_id); + + // Second allocation, should trigger fallback from CUDA -> pinned memory + tc::AllocatedMemory pinned_memory(expect_size, expect_type, expect_id); + + ptr = pinned_memory.BufferAt(0, &actual_size, &actual_type, &actual_id); + EXPECT_EQ(expect_size, actual_size) + << "Expect size: " << expect_size << ", got: " << actual_size; + EXPECT_EQ(TRITONSERVER_MEMORY_CPU_PINNED, actual_type) + << "Expect type: " << TRITONSERVER_MEMORY_CPU_PINNED + << ", got: " << actual_type; + + // Sanity check on the pointer property + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeHost, expect_id); + } + + // Third allocation, should not trigger fallback + tc::AllocatedMemory memory(expect_size, expect_type, expect_id); + + auto ptr = memory.BufferAt(0, &actual_size, &actual_type, &actual_id); + EXPECT_EQ(expect_size, actual_size) 
+ << "Expect size: " << expect_size << ", got: " << actual_size; + EXPECT_EQ(expect_type, actual_type) + << "Expect type: " << expect_type << ", got: " << actual_type; + + // Sanity check on the pointer property + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeDevice, expect_id); +} + +} // namespace + +int +main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/3rdparty/core-r22.12/src/test/metrics_api_test.cc b/3rdparty/core-r22.12/src/test/metrics_api_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f99595a6cd9c8e20fcd8205addcaf9a1b132021f --- /dev/null +++ b/3rdparty/core-r22.12/src/test/metrics_api_test.cc @@ -0,0 +1,678 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#ifdef TRITON_ENABLE_METRICS + +#include +#include +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "metric_family.h" +#include "triton/common/logging.h" +#include "triton/core/tritonserver.h" + +namespace tc = triton::core; + +namespace { + +using ::testing::HasSubstr; + +#define FAIL_TEST_IF_ERR(X, MSG) \ + do { \ + std::shared_ptr err__((X), TRITONSERVER_ErrorDelete); \ + ASSERT_TRUE((err__ == nullptr)) \ + << "error: " << (MSG) << ": " \ + << TRITONSERVER_ErrorCodeString(err__.get()) << " - " \ + << TRITONSERVER_ErrorMessage(err__.get()); \ + } while (false) + +/* Helpers */ + +// Get serialized metrics string from C API +void +GetMetrics(TRITONSERVER_Server* server, std::string* metrics_str) +{ + // Check metrics via C API + ASSERT_NE(server, nullptr); + TRITONSERVER_Metrics* metrics = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerMetrics(server, &metrics), "fetch metrics"); + const char* base; + size_t byte_size; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricsFormatted( + metrics, TRITONSERVER_METRIC_PROMETHEUS, &base, &byte_size), + "format metrics string"); + *metrics_str = std::string(base, byte_size); + TRITONSERVER_MetricsDelete(metrics); +} + +// Count number of times substr appears in s +int +CountMatches(const std::string s, const std::string substr) +{ + int num_matches = 0; + std::string::size_type pos = 0; + while ((pos = s.find(substr, pos)) != std::string::npos) { + num_matches++; + pos += substr.length(); + } + return num_matches; +} + +int +NumMetricMatches(TRITONSERVER_Server* server, const std::string substr) +{ + std::string metrics_str; + GetMetrics(server, &metrics_str); + const int num_matches = CountMatches(metrics_str, substr); + return num_matches; +} + +// Add two metrics with the same labels from the same metric family +// and verify they refer to the same metric/value +void +DupeMetricHelper( + TRITONSERVER_Server* server, + std::vector labels) +{ + // Create metric family + TRITONSERVER_MetricFamily* family = nullptr; + TRITONSERVER_MetricKind kind = TRITONSERVER_METRIC_KIND_COUNTER; + const char* name = "dupe_metric_test"; + const char* description = "dupe metric description"; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family, kind, name, description), + "Creating new metric family1"); + + // Create metric + TRITONSERVER_Metric* metric1 = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricNew(&metric1, family, labels.data(), labels.size()), + "Creating new metric"); + + // Create duplicate metric + TRITONSERVER_Metric* metric2 = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricNew(&metric2, family, labels.data(), labels.size()), + "Creating new metric"); + + // Verify dupe metrics reference same underlying metric + double value1 = -1; + double value2 = -1; + double inc = 7.5; + + // Verify initial values of zero + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricValue(metric1, &value1), + "query metric value after increment"); + ASSERT_EQ(value1, 0); + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricValue(metric2, &value2), + "query metric value after increment"); + ASSERT_EQ(value2, 0); + + // Increment metric 1, check metric 2 == metric 1 + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricIncrement(metric1, inc), "increase metric value"); + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricValue(metric1, &value1), + "query metric value after increment"); + ASSERT_EQ(value1, inc); + + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricValue(metric2, &value2), + "query metric value after increment"); + ASSERT_EQ(value1, value2); + std::cout << "metric1 value: " << value1 << " == 
metric2 value: " << value2 + << std::endl; + + // Assert custom metric/family remains when there's still a reference to it + FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric1), "delete metric1"); + ASSERT_EQ(NumMetricMatches(server, description), 1); + + // Assert custom metric/family not displayed after all metrics are deleted + FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric2), "delete metric2"); + ASSERT_EQ(NumMetricMatches(server, description), 0); + FAIL_TEST_IF_ERR(TRITONSERVER_MetricFamilyDelete(family), "delete family"); +} + +void +MetricAPIHelper(TRITONSERVER_Metric* metric, TRITONSERVER_MetricKind kind) +{ + double value = -1; + double prev_value = -1; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricValue(metric, &value), "query metric initial value"); + // Value should be zero initially + ASSERT_EQ(value, 0.0); + + // Increment positively + double increment = 1729.0; + prev_value = value; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricIncrement(metric, increment), "increase metric value"); + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricValue(metric, &value), + "query metric value after positive increment"); + ASSERT_EQ(value, prev_value + increment); + + // Increment negatively + double decrement = -3.14; + prev_value = value; + auto err = TRITONSERVER_MetricIncrement(metric, decrement); + switch (kind) { + case TRITONSERVER_METRIC_KIND_COUNTER: { + ASSERT_NE(err, nullptr); + break; + } + case TRITONSERVER_METRIC_KIND_GAUGE: { + ASSERT_EQ(err, nullptr); + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricValue(metric, &value), + "query metric value after negative increment"); + ASSERT_EQ(value, prev_value + decrement); + break; + } + default: + ASSERT_TRUE(false); + break; + } + + // Set + double set_value = 42.0; + err = TRITONSERVER_MetricSet(metric, set_value); + switch (kind) { + case TRITONSERVER_METRIC_KIND_COUNTER: { + ASSERT_NE(err, nullptr); + break; + } + case TRITONSERVER_METRIC_KIND_GAUGE: { + ASSERT_EQ(err, nullptr); + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricValue(metric, &value), + "query metric value after set"); + ASSERT_EQ(value, set_value); + break; + } + default: + ASSERT_TRUE(false); + break; + } + + // MetricKind + TRITONSERVER_MetricKind kind_tmp; + FAIL_TEST_IF_ERR( + TRITONSERVER_GetMetricKind(metric, &kind_tmp), "query metric kind"); + ASSERT_EQ(kind_tmp, kind); + TRITONSERVER_ErrorDelete(err); +} + + +// Test Fixture +class MetricsApiTest : public ::testing::Test { + protected: + // Run only once before entire set of tests + static void SetUpTestSuite() {} + // Run only once after entire set of tests + static void TearDownTestSuite() {} + + // Run before each test + void SetUp() override + { + // Create server object to pass when retrieving metrics. + // NOTE: It is currently not required to pass a valid server object to + // TRITONSERVER_ServerMetrics, but is more future-proof to include. + TRITONSERVER_ServerOptions* server_options = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsNew(&server_options), + "creating server options"); + // Mute info output for the sake of this test, less output + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetLogInfo(server_options, false), + "disabling log INFO for brevity"); + // This test doesn't require the use of any models, so we use "." 
as repo + // and set ModelControlMode to EXPLICIT to avoid attempting to load models + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetModelRepositoryPath(server_options, "."), + "setting model repository path"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetModelControlMode( + server_options, TRITONSERVER_MODEL_CONTROL_EXPLICIT), + "setting model control mode"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerNew(&server_, server_options), "creating server"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsDelete(server_options), + "deleting server options"); + } + + // Run after each test + void TearDown() override + { + FAIL_TEST_IF_ERR(TRITONSERVER_ServerDelete(server_), "deleting server"); + } + + static TRITONSERVER_Server* server_; +}; + +TRITONSERVER_Server* MetricsApiTest::server_ = nullptr; + +// Test end-to-end flow of Generic Metrics API for Counter metric +TEST_F(MetricsApiTest, TestCounterEndToEnd) +{ + // Create metric family + TRITONSERVER_MetricFamily* family; + TRITONSERVER_MetricKind kind = TRITONSERVER_METRIC_KIND_COUNTER; + const char* name = "custom_counter_example"; + const char* description = "this is an example counter metric added via API."; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family, kind, name, description), + "Creating new metric family"); + + // Create metric + TRITONSERVER_Metric* metric; + std::vector labels; + labels.emplace_back(TRITONSERVER_ParameterNew( + "example1", TRITONSERVER_PARAMETER_STRING, "counter_label1")); + labels.emplace_back(TRITONSERVER_ParameterNew( + "example2", TRITONSERVER_PARAMETER_STRING, "counter_label2")); + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricNew(&metric, family, labels.data(), labels.size()), + "Creating new metric"); + for (const auto label : labels) { + TRITONSERVER_ParameterDelete(const_cast(label)); + } + + // Run through metric APIs and assert correctness + MetricAPIHelper(metric, kind); + + // Assert custom metric is reported and found in output + ASSERT_EQ(NumMetricMatches(server_, description), 1); + + // Cleanup + FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric), "delete metric"); + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyDelete(family), "delete metric family"); + + // Assert custom metric/family is unregistered and no longer in output + ASSERT_EQ(NumMetricMatches(server_, description), 0); +} + +// Test end-to-end flow of Generic Metrics API for Gauge metric +TEST_F(MetricsApiTest, TestGaugeEndToEnd) +{ + // Create metric family + TRITONSERVER_MetricFamily* family; + TRITONSERVER_MetricKind kind = TRITONSERVER_METRIC_KIND_GAUGE; + const char* name = "custom_gauge_example"; + const char* description = "this is an example gauge metric added via API."; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family, kind, name, description), + "Creating new metric family"); + + // Create metric + TRITONSERVER_Metric* metric; + std::vector labels; + labels.emplace_back(TRITONSERVER_ParameterNew( + "example1", TRITONSERVER_PARAMETER_STRING, "gauge_label1")); + labels.emplace_back(TRITONSERVER_ParameterNew( + "example2", TRITONSERVER_PARAMETER_STRING, "gauge_label2")); + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricNew(&metric, family, labels.data(), labels.size()), + "Creating new metric"); + for (const auto label : labels) { + TRITONSERVER_ParameterDelete(const_cast(label)); + } + + // Run through metric APIs and assert correctness + MetricAPIHelper(metric, kind); + + // Assert custom metric is reported and found in output + ASSERT_EQ(NumMetricMatches(server_, description), 1); + + // Cleanup + 
FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric), "delete metric"); + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyDelete(family), "delete metric family"); + + // Assert custom metric/family is unregistered and no longer in output + ASSERT_EQ(NumMetricMatches(server_, description), 0); +} + +// Test that a duplicate metric family can't be added +// with a conflicting type/kind +TEST_F(MetricsApiTest, TestDupeMetricFamilyDiffKind) +{ + // Create metric family + TRITONSERVER_MetricFamily* family1 = nullptr; + TRITONSERVER_MetricKind kind1 = TRITONSERVER_METRIC_KIND_COUNTER; + const char* name = "diff_kind_test"; + const char* description = "diff kind description"; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family1, kind1, name, description), + "Creating new metric family1"); + + // Create duplicate metric family with different kind + TRITONSERVER_MetricFamily* family2 = nullptr; + TRITONSERVER_MetricKind kind2 = TRITONSERVER_METRIC_KIND_GAUGE; + // Expect this to fail, can't have duplicate name of different kind + auto err = TRITONSERVER_MetricFamilyNew(&family2, kind2, name, description); + ASSERT_NE(err, nullptr); + ASSERT_EQ(family2, nullptr); + TRITONSERVER_ErrorDelete(err); +} + +// Test that a duplicate metric family name will still +// return the original metric family even if the description +// is changed +TEST_F(MetricsApiTest, TestDupeMetricFamilyDiffDescription) +{ + // Create metric family + TRITONSERVER_MetricFamily* family1 = nullptr; + TRITONSERVER_MetricKind kind = TRITONSERVER_METRIC_KIND_COUNTER; + const char* name = "diff_description_test"; + const char* description1 = "first description"; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family1, kind, name, description1), + "Creating new metric family1"); + + // Create duplicate metric family + TRITONSERVER_MetricFamily* family2 = nullptr; + const char* description2 = "second description"; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family2, kind, name, description2), + "Creating new metric family2"); + + // Assert MetricFamily is not reported until metrics are added to them + ASSERT_EQ(NumMetricMatches(server_, description1), 0); + ASSERT_EQ(NumMetricMatches(server_, description2), 0); + + // Add metric to family2 only, this will be shared by family1 as well + // since both families refer to the same underlying prometheus family + std::vector labels; + TRITONSERVER_Metric* metric2 = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricNew(&metric2, family2, labels.data(), labels.size()), + "Creating new metric2"); + + // Assert MetricFamily is reported exactly once + // This confirms attempting to add a duplicate returns the existing family + ASSERT_EQ(NumMetricMatches(server_, description1), 1); + // The first description will be taken/kept if adding a duplicate + /// metric family name, even with a different description + ASSERT_EQ(NumMetricMatches(server_, description2), 0); + + // Delete one of the metric family references + // Specificailly, family1, because family2 is bound to metric2 + FAIL_TEST_IF_ERR(TRITONSERVER_MetricFamilyDelete(family1), "delete family1"); + + // Assert custom metric/family remains when family2 still references it + ASSERT_EQ(NumMetricMatches(server_, description1), 1); + + // Assert custom metric/family unregistered after last reference deleted + FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric2), "delete metric2"); + FAIL_TEST_IF_ERR(TRITONSERVER_MetricFamilyDelete(family2), "delete family2"); + ASSERT_EQ(NumMetricMatches(server_, description1), 0); + 
ASSERT_EQ(NumMetricMatches(server_, description2), 0); +} + +// Test that adding a duplicate metric family will reuse the original +// and not add another entry to registry +TEST_F(MetricsApiTest, TestDupeMetricFamily) +{ + // Create metric family + TRITONSERVER_MetricFamily* family1 = nullptr; + TRITONSERVER_MetricKind kind = TRITONSERVER_METRIC_KIND_COUNTER; + const char* name = "dupe_metric_family_test"; + const char* description = "dupe metric family description"; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family1, kind, name, description), + "Creating new metric family1"); + + // Create duplicate metric family + TRITONSERVER_MetricFamily* family2 = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family2, kind, name, description), + "Creating new metric family2"); + + // Assert MetricFamily is not reported until metrics are added to them + ASSERT_EQ(NumMetricMatches(server_, description), 0); + + // Create unique metrics for each family object. Both family objects + // will refer to the same prometheus family in the registry, so both + // metrics should be displayed under the family. + const char* metric_key = "custom_metric_key"; + std::vector labels1; + labels1.emplace_back(TRITONSERVER_ParameterNew( + metric_key, TRITONSERVER_PARAMETER_STRING, "label1")); + TRITONSERVER_Metric* metric1 = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricNew(&metric1, family1, labels1.data(), labels1.size()), + "Creating new metric1"); + for (const auto label : labels1) { + TRITONSERVER_ParameterDelete(const_cast(label)); + } + + std::vector labels2; + labels2.emplace_back(TRITONSERVER_ParameterNew( + metric_key, TRITONSERVER_PARAMETER_STRING, "label2")); + TRITONSERVER_Metric* metric2 = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricNew(&metric2, family2, labels2.data(), labels2.size()), + "Creating new metric2"); + for (const auto label : labels2) { + TRITONSERVER_ParameterDelete(const_cast(label)); + } + + // Assert MetricFamily is reported exactly once + // This confirms attempting to add a duplicate returns the existing family + ASSERT_EQ(NumMetricMatches(server_, description), 1); + // Assert we have two unique metrics + ASSERT_EQ(NumMetricMatches(server_, metric_key), 2); + + // Delete one of the metric family references + FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric1), "delete metric1"); + FAIL_TEST_IF_ERR(TRITONSERVER_MetricFamilyDelete(family1), "delete family1"); + + // Assert custom family remains when there's still a reference to it + ASSERT_EQ(NumMetricMatches(server_, description), 1); + // Assert only one remaining metric after deleting one + ASSERT_EQ(NumMetricMatches(server_, metric_key), 1); + + // Assert custom metric/family unregistered after last reference deleted + FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric2), "delete metric2"); + FAIL_TEST_IF_ERR(TRITONSERVER_MetricFamilyDelete(family2), "delete family2"); + ASSERT_EQ(NumMetricMatches(server_, description), 0); + // Assert no remaining metrics after deleting both + ASSERT_EQ(NumMetricMatches(server_, metric_key), 0); +} + +// Test that adding a duplicate metric will refer to the same +// underlying metric, and all instances will be updated +TEST_F(MetricsApiTest, TestDupeMetricLabels) +{ + std::vector labels; + labels.emplace_back(TRITONSERVER_ParameterNew( + "example1", TRITONSERVER_PARAMETER_STRING, "label1")); + labels.emplace_back(TRITONSERVER_ParameterNew( + "example2", TRITONSERVER_PARAMETER_STRING, "label2")); + + DupeMetricHelper(server_, labels); + + for (const auto label : labels) 
{ + TRITONSERVER_ParameterDelete(const_cast(label)); + } +} + +// Test that adding a duplicate metric will refer to the same +// underlying metric, and all instances will be updated +TEST_F(MetricsApiTest, TestDupeMetricEmptyLabels) +{ + std::vector labels; + DupeMetricHelper(server_, labels); +} + +TEST_F(MetricsApiTest, TestOutOfOrderDelete) +{ + // Create metric family + TRITONSERVER_MetricFamily* family = nullptr; + TRITONSERVER_MetricKind kind = TRITONSERVER_METRIC_KIND_COUNTER; + const char* name = "out_of_order_delete"; + const char* description = "out of order delete test"; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family, kind, name, description), + "Creating new metric family"); + + // Add metric to family + std::vector labels; + TRITONSERVER_Metric* metric = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricNew(&metric, family, labels.data(), labels.size()), + "Creating new metric"); + + // Check that deleting metric family BEFORE metric fails + auto err = TRITONSERVER_MetricFamilyDelete(family); + EXPECT_THAT( + TRITONSERVER_ErrorMessage(err), HasSubstr("Must call MetricDelete")); + + // Check that deleting in correct order still works after above failure + FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric), "deleting metric"); + FAIL_TEST_IF_ERR(TRITONSERVER_MetricFamilyDelete(family), "deleting family"); + TRITONSERVER_ErrorDelete(err); +} + +TEST_F(MetricsApiTest, TestMetricAfterFamilyDelete) +{ + // Create metric family + TRITONSERVER_MetricFamily* family = nullptr; + TRITONSERVER_MetricKind kind = TRITONSERVER_METRIC_KIND_GAUGE; + const char* name = "use_metric_after_family_delete"; + const char* description = "test using a metric after its family is deleted"; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family, kind, name, description), + "Creating new metric family"); + + // Add metric to family + std::vector labels; + TRITONSERVER_Metric* metric = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricNew(&metric, family, labels.data(), labels.size()), + "Creating new metric"); + + // Check that deleting metric family BEFORE metric fails + auto err = TRITONSERVER_MetricFamilyDelete(family); + EXPECT_THAT( + TRITONSERVER_ErrorMessage(err), HasSubstr("Must call MetricDelete")); + + // Use internal implementation to force deletion since C API checks first + // NOTE: This is for internal testing and should NOT be done by users. + delete reinterpret_cast(family); + + // Expected API calls to fail since metric has been invalidated by + // calling MetricFamilyDelete before MetricDelete + double value = -1; + err = TRITONSERVER_MetricValue(metric, &value); + EXPECT_THAT(TRITONSERVER_ErrorMessage(err), HasSubstr("invalidated")); + err = TRITONSERVER_MetricIncrement(metric, 1.0); + EXPECT_THAT(TRITONSERVER_ErrorMessage(err), HasSubstr("invalidated")); + err = TRITONSERVER_MetricSet(metric, 1.0); + EXPECT_THAT(TRITONSERVER_ErrorMessage(err), HasSubstr("invalidated")); + TRITONSERVER_ErrorDelete(err); +} + +// This test serves as a reminder to consider the ability to access +// internal core metrics via current metrics API and its implications. +TEST_F(MetricsApiTest, TestCoreMetricAccess) +{ + // Test accessing a metric family created in Triton Core + // through prometheus directly. Technically this metric can be + // updated manually by a user in addition to how the core manages + // the metric, but this should generally not be done. 
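+  // For illustration only (not executed here): once a family object is
+  // attached to an existing core family, a user could in principle create
+  // and update their own metric under it, e.g.:
+  //   TRITONSERVER_Metric* metric = nullptr;
+  //   TRITONSERVER_MetricNew(&metric, family, nullptr /* labels */, 0);
+  //   TRITONSERVER_MetricSet(metric, 42.0);
+  //   TRITONSERVER_MetricDelete(metric);
+  // This test only creates and deletes the family to document that reusing a
+  // core metric name is currently accepted rather than rejected.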
+ TRITONSERVER_MetricFamily* family = nullptr; + TRITONSERVER_MetricKind kind = TRITONSERVER_METRIC_KIND_GAUGE; + // Pick existing core metric name here. + const char* name = "nv_gpu_power_limit"; + const char* description = ""; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family, kind, name, description), + "Creating new metric family"); + // DLIS-4072: If registry->Remove() is implemented in MetricFamily we will + // we will probably want to make sure core metrics can not be deleted early. + FAIL_TEST_IF_ERR(TRITONSERVER_MetricFamilyDelete(family), "delete family"); +} + +TEST_F(MetricsApiTest, TestChildMetricTracking) +{ + // Create metric family + TRITONSERVER_MetricFamily* family = nullptr; + TRITONSERVER_MetricKind kind = TRITONSERVER_METRIC_KIND_GAUGE; + const char* name = "test_ref_counting"; + const char* description = "test using metric ref counting"; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricFamilyNew(&family, kind, name, description), + "Creating new metric family"); + + // Use internal implementation to verify correctness + auto tc_family = reinterpret_cast(family); + + // Create metric + TRITONSERVER_Metric* metric1 = nullptr; + std::vector labels; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricNew(&metric1, family, labels.data(), labels.size()), + "Creating new metric1"); + ASSERT_EQ(tc_family->NumMetrics(), 1); + + // Create duplicate metric + TRITONSERVER_Metric* metric2 = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_MetricNew(&metric2, family, labels.data(), labels.size()), + "Creating new metric2"); + ASSERT_EQ(tc_family->NumMetrics(), 2); + + + FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric1), "delete metric1"); + ASSERT_EQ(tc_family->NumMetrics(), 1); + FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric2), "delete metric2"); + ASSERT_EQ(tc_family->NumMetrics(), 0); + FAIL_TEST_IF_ERR(TRITONSERVER_MetricFamilyDelete(family), "delete family"); +} + +} // namespace + +int +main(int argc, char** argv) +{ +#ifdef TRITON_ENABLE_LOGGING + LOG_SET_VERBOSE(1); +#endif // TRITON_ENABLE_LOGGING + + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +#endif // TRITON_ENABLE_METRICS diff --git a/3rdparty/core-r22.12/src/test/pinned_memory_manager_test.cc b/3rdparty/core-r22.12/src/test/pinned_memory_manager_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d618047853192df9947ef5bbf7837dec07ebcf89 --- /dev/null +++ b/3rdparty/core-r22.12/src/test/pinned_memory_manager_test.cc @@ -0,0 +1,320 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include "pinned_memory_manager.h" +#include "tritonserver_apis.h" + +namespace tc = triton::core; + +namespace { + +#define CHECK_POINTER_ATTRIBUTES(ptr__, type__, device__) \ + do { \ + cudaPointerAttributes attr; \ + auto cuerr = cudaPointerGetAttributes(&attr, ptr__); \ + ASSERT_TRUE(cuerr == cudaSuccess) \ + << "Failed to get CUDA pointer attributes: " \ + << cudaGetErrorString(cuerr); \ + EXPECT_TRUE(attr.type == type__) \ + << "Expect pointer with type " << type__ << ", got: " << attr.type; \ + if (attr.type == cudaMemoryTypeDevice) { \ + EXPECT_TRUE(attr.device == device__) \ + << "Expect allocation on CUDA device " << device__ \ + << ", got: " << attr.device; \ + } \ + } while (false) + +#define STORE_RESULT_AND_RETURN_IF_ERROR(metadata__, idx__, status__) \ + do { \ + if (!status__.IsOk()) { \ + std::lock_guard lk(metadata__->mtx_); \ + metadata__->results_[idx__] = status__.AsString(); \ + return; \ + } \ + } while (false) + +struct MemoryWorkMetadata { + MemoryWorkMetadata(size_t thread_count) + : thread_count_(thread_count), ready_count_(0), results_(thread_count, "") + { + } + size_t thread_count_; + size_t ready_count_; + std::vector results_; + std::mutex mtx_; + std::condition_variable cv_; +}; + +void +RunMemoryWork( + size_t idx, size_t alloc_size, bool allow_nonpinned_fallback, + MemoryWorkMetadata* metadata) +{ + // Prepare variable to hold input / output + std::unique_ptr input(new char[alloc_size]); + std::unique_ptr output(new char[alloc_size]); + + // Wait until all threads are issued + { + std::unique_lock lk(metadata->mtx_); + metadata->ready_count_++; + if (metadata->ready_count_ != metadata->thread_count_) { + while (metadata->ready_count_ != metadata->thread_count_) { + metadata->cv_.wait(lk); + } + } + metadata->cv_.notify_one(); + } + + // Simulate receive input data -> alloc and write to input buffer + // -> alloc and write to output buffer -> return output data + TRITONSERVER_MemoryType allocated_type = TRITONSERVER_MEMORY_GPU; + void* input_buffer = nullptr; + STORE_RESULT_AND_RETURN_IF_ERROR( + metadata, idx, + tc::PinnedMemoryManager::Alloc( + &input_buffer, alloc_size, &allocated_type, + allow_nonpinned_fallback)); + if ((!allow_nonpinned_fallback) && + (allocated_type != TRITONSERVER_MEMORY_CPU_PINNED)) { + tc::Status status( + tc::Status::Code::INVALID_ARG, "returned memory buffer is not pinned"); + STORE_RESULT_AND_RETURN_IF_ERROR(metadata, idx, status); + } + memcpy(input_buffer, input.get(), alloc_size); + void* output_buffer = nullptr; + STORE_RESULT_AND_RETURN_IF_ERROR( + metadata, idx, + tc::PinnedMemoryManager::Alloc( + &output_buffer, alloc_size, &allocated_type, + allow_nonpinned_fallback)); + if ((!allow_nonpinned_fallback) && + (allocated_type != TRITONSERVER_MEMORY_CPU_PINNED)) { + tc::Status status( + tc::Status::Code::INVALID_ARG, "returned memory buffer is not pinned"); + STORE_RESULT_AND_RETURN_IF_ERROR(metadata, idx, 
status); + } + memcpy(output_buffer, input_buffer, alloc_size); + memcpy(output.get(), output_buffer, alloc_size); + for (size_t offset = 0; offset < alloc_size; offset++) { + if (input.get()[offset] != output.get()[offset]) { + std::lock_guard lk(metadata->mtx_); + metadata->results_[idx] = + std::string("mismatch between input and output for work idx ") + + std::to_string(idx); + return; + } + } +} + +// Wrapper of PinnedMemoryManager class to expose Reset() for unit testing +class TestingPinnedMemoryManager : public tc::PinnedMemoryManager { + public: + static void Reset() { PinnedMemoryManager::Reset(); } +}; + +class PinnedMemoryManagerTest : public ::testing::Test { + protected: + void SetUp() override + { + // Default memory manager options + options_.pinned_memory_pool_byte_size_ = 1 << 10; + } + + void TearDown() override { TestingPinnedMemoryManager::Reset(); } + + tc::PinnedMemoryManager::Options options_; +}; + +TEST_F(PinnedMemoryManagerTest, InitOOM) +{ + // Set to reserve too much memory + options_.pinned_memory_pool_byte_size_ = uint64_t(1) << 40 /* 1024 GB */; + auto status = tc::PinnedMemoryManager::Create(options_); + // For pinned memory manager, it will still be created for "CPU fallback" + // allocation even if it fails to create pinned memory pool + EXPECT_TRUE(status.IsOk()) << status.Message(); +} + +TEST_F(PinnedMemoryManagerTest, InitSuccess) +{ + auto status = tc::PinnedMemoryManager::Create(options_); + EXPECT_TRUE(status.IsOk()) << status.Message(); +} + +TEST_F(PinnedMemoryManagerTest, InitZeroByte) +{ + options_.pinned_memory_pool_byte_size_ = 0; + auto status = tc::PinnedMemoryManager::Create(options_); + EXPECT_TRUE(status.IsOk()) << status.Message(); + + void* ptr = nullptr; + TRITONSERVER_MemoryType allocated_type = TRITONSERVER_MEMORY_GPU; + status = tc::PinnedMemoryManager::Alloc( + &ptr, 1, &allocated_type, false /* allow_nonpinned_fallback */); + ASSERT_FALSE(status.IsOk()) << "Unexpected successful allocation"; +} + +TEST_F(PinnedMemoryManagerTest, AllocSuccess) +{ + auto status = tc::PinnedMemoryManager::Create(options_); + ASSERT_TRUE(status.IsOk()) << status.Message(); + + void* ptr = nullptr; + TRITONSERVER_MemoryType allocated_type = TRITONSERVER_MEMORY_GPU; + status = tc::PinnedMemoryManager::Alloc( + &ptr, 512, &allocated_type, false /* allow_nonpinned_fallback */); + ASSERT_TRUE(status.IsOk()) << status.Message(); + ASSERT_TRUE(ptr) << "Expect pointer to allocated buffer"; + ASSERT_TRUE(allocated_type == TRITONSERVER_MEMORY_CPU_PINNED) + << "Expect pointer to pinned memory"; + // check if returned pointer is pinned memory pointer + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeHost, 0); +} + +TEST_F(PinnedMemoryManagerTest, AllocFallbackSuccess) +{ + auto status = tc::PinnedMemoryManager::Create(options_); + ASSERT_TRUE(status.IsOk()) << status.Message(); + + void* ptr = nullptr; + TRITONSERVER_MemoryType allocated_type = TRITONSERVER_MEMORY_GPU; + status = tc::PinnedMemoryManager::Alloc( + &ptr, 2048, &allocated_type, true /* allow_nonpinned_fallback */); + ASSERT_TRUE(status.IsOk()) << status.Message(); + ASSERT_TRUE(ptr) << "Expect pointer to allocated buffer"; + ASSERT_TRUE(allocated_type == TRITONSERVER_MEMORY_CPU) + << "Expect pointer to non-pinned memory"; + // check if returned pointer is non-pinned memory pointer + CHECK_POINTER_ATTRIBUTES(ptr, cudaMemoryTypeUnregistered, 0); +} + +TEST_F(PinnedMemoryManagerTest, AllocFail) +{ + auto status = tc::PinnedMemoryManager::Create(options_); + ASSERT_TRUE(status.IsOk()) << status.Message(); 
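+  // The fixture's default pool is only 1 KB (1 << 10 bytes, see SetUp), so a
+  // 2048-byte request with fallback disabled is expected to fail.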
+ + void* ptr = nullptr; + TRITONSERVER_MemoryType allocated_type = TRITONSERVER_MEMORY_GPU; + status = tc::PinnedMemoryManager::Alloc( + &ptr, 2048, &allocated_type, false /* allow_nonpinned_fallback */); + ASSERT_FALSE(status.IsOk()) << "Unexpected successful allocation"; +} + +TEST_F(PinnedMemoryManagerTest, MultipleAlloc) +{ + auto status = tc::PinnedMemoryManager::Create(options_); + ASSERT_TRUE(status.IsOk()) << status.Message(); + + void* first_ptr = nullptr; + TRITONSERVER_MemoryType allocated_type = TRITONSERVER_MEMORY_GPU; + status = tc::PinnedMemoryManager::Alloc( + &first_ptr, 600, &allocated_type, false /* allow_nonpinned_fallback */); + ASSERT_TRUE(status.IsOk()) << status.Message(); + ASSERT_TRUE(first_ptr) << "Expect pointer to allocated buffer"; + ASSERT_TRUE(allocated_type == TRITONSERVER_MEMORY_CPU_PINNED) + << "Expect pointer to pinned memory"; + // check if returned pointer is pinned memory pointer + CHECK_POINTER_ATTRIBUTES(first_ptr, cudaMemoryTypeHost, 0); + + // 512 + 600 > 1024 + void* second_ptr = nullptr; + status = tc::PinnedMemoryManager::Alloc( + &second_ptr, 512, &allocated_type, false /* allow_nonpinned_fallback */); + ASSERT_FALSE(status.IsOk()) << "Unexpected successful allocation"; + + // Free the first pointer and retry the second one + status = tc::PinnedMemoryManager::Free(first_ptr); + EXPECT_TRUE(status.IsOk()) << status.Message(); + status = tc::PinnedMemoryManager::Alloc( + &second_ptr, 512, &allocated_type, false /* allow_nonpinned_fallback */); + ASSERT_TRUE(status.IsOk()) << status.Message(); + ASSERT_TRUE(second_ptr) << "Expect pointer to allocated buffer"; + ASSERT_TRUE(allocated_type == TRITONSERVER_MEMORY_CPU_PINNED) + << "Expect pointer to pinned memory"; + // check if returned pointer is pinned memory pointer + CHECK_POINTER_ATTRIBUTES(second_ptr, cudaMemoryTypeHost, 0); +} + +TEST_F(PinnedMemoryManagerTest, ParallelAlloc) +{ + options_.pinned_memory_pool_byte_size_ = uint64_t(1) << 28 /* 256 MB */; + auto status = tc::PinnedMemoryManager::Create(options_); + ASSERT_TRUE(status.IsOk()) << status.Message(); + + // Create threads to perform operations on allocated memory in parallel + // Seems like for 1 MB alloc size (2 MB for both input and output), + // 100 threads is a good amount for pool manager not to use CPU fallback. 
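+  // (Each worker pins an input and an output buffer of alloc_size bytes, so
+  // 100 threads * 2 MB = ~200 MB, which fits in the 256 MB pool above.)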
+ size_t thread_count = 100; + size_t allocated_size = 1 << 20 /* 1 MB */; + MemoryWorkMetadata metadata(thread_count); + std::vector threads; + for (size_t idx = 0; idx < thread_count; idx++) { + threads.emplace_back( + std::thread(RunMemoryWork, idx, allocated_size, false, &metadata)); + } + for (size_t idx = 0; idx < thread_count; idx++) { + threads[idx].join(); + EXPECT_TRUE(metadata.results_[idx].empty()) << metadata.results_[idx]; + } +} + + +TEST_F(PinnedMemoryManagerTest, ParallelAllocFallback) +{ + options_.pinned_memory_pool_byte_size_ = uint64_t(1) << 28 /* 256 MB */; + auto status = tc::PinnedMemoryManager::Create(options_); + ASSERT_TRUE(status.IsOk()) << status.Message(); + + // Create threads to perform operations on allocated memory in parallel + size_t thread_count = 128; + size_t allocated_size = 1 << 24 /* 4 MB */; + MemoryWorkMetadata metadata(thread_count); + std::vector threads; + for (size_t idx = 0; idx < thread_count; idx++) { + threads.emplace_back( + std::thread(RunMemoryWork, idx, allocated_size, true, &metadata)); + } + for (size_t idx = 0; idx < thread_count; idx++) { + threads[idx].join(); + EXPECT_TRUE(metadata.results_[idx].empty()) << metadata.results_[idx]; + } +} + +} // namespace + +int +main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/3rdparty/core-r22.12/src/test/query_test.cc b/3rdparty/core-r22.12/src/test/query_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a3e78176359704af1accaf44f8443386c933c4cc --- /dev/null +++ b/3rdparty/core-r22.12/src/test/query_test.cc @@ -0,0 +1,368 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
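+
+// These tests exercise the response allocator query callback
+// (TRITONSERVER_ResponseAllocatorSetQueryFunction): inference is run against
+// a test model named "query", and the tests record the output buffer
+// properties that Triton queries from the allocator as well as the
+// allocations it then makes.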
+#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "triton/core/tritonserver.h" + +namespace { + +using ::testing::HasSubstr; + +#define FAIL_TEST_IF_ERR(X, MSG) \ + do { \ + std::shared_ptr err__((X), TRITONSERVER_ErrorDelete); \ + ASSERT_TRUE((err__ == nullptr)) \ + << "error: " << (MSG) << ": " \ + << TRITONSERVER_ErrorCodeString(err__.get()) << " - " \ + << TRITONSERVER_ErrorMessage(err__.get()); \ + } while (false) + +using NameMap = + std::map>; +struct QueryTracker { + QueryTracker( + const char* tensor_name, size_t* byte_size, + TRITONSERVER_MemoryType memory_type, int64_t memory_type_id) + : has_name_(tensor_name != nullptr), has_byte_size_(byte_size != nullptr), + caller_preferred_type_(memory_type), + caller_preferred_id_(memory_type_id) + { + if (has_name_) { + name_ = tensor_name; + } + if (has_byte_size_) { + byte_size_ = *byte_size; + } + } + bool has_name_; + bool has_byte_size_; + std::string name_; + size_t byte_size_; + TRITONSERVER_MemoryType caller_preferred_type_; + int64_t caller_preferred_id_; +}; + +TRITONSERVER_Error* +ResponseAlloc( + TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type, + int64_t preferred_memory_type_id, void* userp, void** buffer, + void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type, + int64_t* actual_memory_type_id) +{ + auto& output_tracker = + (reinterpret_cast, NameMap>*>(userp) + ->second); + output_tracker.emplace( + tensor_name, + std::make_tuple( + preferred_memory_type, preferred_memory_type_id, byte_size)); + return nullptr; // Success +} + +TRITONSERVER_Error* +ResponseRelease( + TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp, + size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) +{ + return nullptr; // Success +} + +void +InferRequestComplete( + TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp) +{ + TRITONSERVER_InferenceRequestDelete(request); +} + +void +InferResponseComplete( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp) +{ + if (response != nullptr) { + // Notify that the completion. + std::promise* p = + reinterpret_cast*>(userp); + p->set_value(TRITONSERVER_InferenceResponseError(response)); + } + TRITONSERVER_InferenceResponseDelete(response); +} + +class QueryTest : public ::testing::Test { + protected: + static void SetUpTestSuite() + { + // Create the server... 
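+    // The repository "./models" is expected to contain the test model named
+    // "query" used below; strict model configuration is enabled, so the
+    // model's config.pbtxt must be fully specified.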
+ TRITONSERVER_ServerOptions* server_options = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsNew(&server_options), + "creating server options"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetModelRepositoryPath( + server_options, "./models"), + "setting model repository path"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetBackendDirectory( + server_options, "/opt/tritonserver/backends"), + "setting backend directory"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetRepoAgentDirectory( + server_options, "/opt/tritonserver/repoagents"), + "setting repository agent directory"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetStrictModelConfig(server_options, true), + "setting strict model configuration"); + + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerNew(&server_, server_options), "creating server"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsDelete(server_options), + "deleting server options"); + } + + static void TearDownTestSuite() + { + FAIL_TEST_IF_ERR(TRITONSERVER_ServerDelete(server_), "deleting server"); + } + + void SetUp() override + { + ASSERT_TRUE(server_ != nullptr) << "Server has not created"; + // Wait until the server is both live and ready. + size_t health_iters = 0; + while (true) { + bool live, ready; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerIsLive(server_, &live), + "unable to get server liveness"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerIsReady(server_, &ready), + "unable to get server readiness"); + if (live && ready) { + break; + } + + if (++health_iters >= 10) { + FAIL() << "failed to find healthy inference server"; + } + + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + + // Create allocator with common callback + FAIL_TEST_IF_ERR( + TRITONSERVER_ResponseAllocatorNew( + &allocator_, ResponseAlloc, ResponseRelease, + nullptr /* start_fn */), + "creating response allocator"); + + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestNew( + &irequest_, server_, "query", -1 /* model_version */), + "creating inference request"); + + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestSetReleaseCallback( + irequest_, InferRequestComplete, + nullptr /* request_release_userp */), + "setting request release callback"); + + std::vector shape{1}; + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestAddInput( + irequest_, "INPUT", TRITONSERVER_TYPE_UINT8, shape.data(), + shape.size()), + "setting input for the request"); + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestAppendInputData( + irequest_, "INPUT", input_data_.data(), input_data_.size(), + TRITONSERVER_MEMORY_CPU, 0), + "assigning INPUT data"); + + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestSetResponseCallback( + irequest_, allocator_, reinterpret_cast(&output_info_), + InferResponseComplete, reinterpret_cast(&completed_)), + "setting response callback"); + } + + void TearDown() override + { + unsetenv("TEST_ANONYMOUS"); + unsetenv("TEST_BYTE_SIZE"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ResponseAllocatorDelete(allocator_), + "deleting response allocator"); + } + + static TRITONSERVER_Server* server_; + TRITONSERVER_ResponseAllocator* allocator_ = nullptr; + static std::vector input_data_; + TRITONSERVER_InferenceRequest* irequest_ = nullptr; + std::promise completed_; + std::pair, NameMap> output_info_; +}; + +TRITONSERVER_Server* QueryTest::server_ = nullptr; +std::vector QueryTest::input_data_{1}; + +TEST_F(QueryTest, DefaultQuery) +{ + TRITONSERVER_ResponseAllocatorQueryFn_t query_fn = + [](TRITONSERVER_ResponseAllocator* allocator, void* userp, + const char* 
tensor_name, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, + int64_t* memory_type_id) -> TRITONSERVER_Error* { + auto& query_tracker = + (reinterpret_cast, NameMap>*>(userp) + ->first); + query_tracker.emplace_back( + tensor_name, byte_size, *memory_type, *memory_type_id); + *memory_type = TRITONSERVER_MEMORY_CPU; + *memory_type_id = 0; + return nullptr; + }; + FAIL_TEST_IF_ERR( + TRITONSERVER_ResponseAllocatorSetQueryFunction(allocator_, query_fn), + "setting response callback"); + + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */), + "running inference"); + + auto err = completed_.get_future().get(); + ASSERT_TRUE(err == nullptr) << "Expect successful inference"; + + // Check query tracker to see if the query function is connected properly + ASSERT_EQ(output_info_.first.size(), size_t(2)); + for (size_t i = 0; i < output_info_.first.size(); ++i) { + const auto& query_info = output_info_.first[i]; + EXPECT_EQ(query_info.has_name_, true); + EXPECT_EQ(query_info.name_, (std::string("OUTPUT") + std::to_string(i))); + EXPECT_EQ(query_info.has_byte_size_, false); + EXPECT_EQ( + query_info.caller_preferred_type_, TRITONSERVER_MEMORY_CPU_PINNED); + EXPECT_EQ(query_info.caller_preferred_id_, 1); + } + + const auto& output_0 = output_info_.second["OUTPUT0"]; + EXPECT_EQ(std::get<0>(output_0), TRITONSERVER_MEMORY_CPU); + EXPECT_EQ(std::get<1>(output_0), int64_t(0)); + EXPECT_EQ(std::get<2>(output_0), size_t(2)); + + const auto& output_1 = output_info_.second["OUTPUT1"]; + EXPECT_EQ(std::get<0>(output_1), TRITONSERVER_MEMORY_CPU); + EXPECT_EQ(std::get<1>(output_1), int64_t(0)); + EXPECT_EQ(std::get<2>(output_1), size_t(2)); +} + +TEST_F(QueryTest, NoQueryFn) +{ + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */), + "running inference"); + + auto err = completed_.get_future().get(); + ASSERT_TRUE(err != nullptr) << "Expect error"; + EXPECT_EQ(TRITONSERVER_ErrorCode(err), TRITONSERVER_ERROR_UNAVAILABLE); + EXPECT_THAT( + TRITONSERVER_ErrorMessage(err), + HasSubstr("Output properties are not available")); +} + +TEST_F(QueryTest, UnnamedQuery) +{ + setenv("TEST_ANONYMOUS", "", 1); + setenv("TEST_BYTE_SIZE", "32", 1); + TRITONSERVER_ResponseAllocatorQueryFn_t query_fn = + [](TRITONSERVER_ResponseAllocator* allocator, void* userp, + const char* tensor_name, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, + int64_t* memory_type_id) -> TRITONSERVER_Error* { + auto& query_tracker = + (reinterpret_cast, NameMap>*>(userp) + ->first); + query_tracker.emplace_back( + tensor_name, byte_size, *memory_type, *memory_type_id); + // Slightly different setting + *memory_type = TRITONSERVER_MEMORY_GPU; + *memory_type_id = 2; + return nullptr; + }; + FAIL_TEST_IF_ERR( + TRITONSERVER_ResponseAllocatorSetQueryFunction(allocator_, query_fn), + "setting response callback"); + + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */), + "running inference"); + + auto err = completed_.get_future().get(); + ASSERT_TRUE(err == nullptr) << "Expect successful inference"; + + // Check query tracker to see if the query function is connected properly + ASSERT_EQ(output_info_.first.size(), size_t(1)); + for (size_t i = 0; i < output_info_.first.size(); ++i) { + const auto& query_info = output_info_.first[i]; + EXPECT_EQ(query_info.has_name_, false); + EXPECT_EQ(query_info.has_byte_size_, true); + EXPECT_EQ(query_info.byte_size_, size_t(32)); + EXPECT_EQ( + query_info.caller_preferred_type_, 
TRITONSERVER_MEMORY_CPU_PINNED); + EXPECT_EQ(query_info.caller_preferred_id_, 1); + } + + const auto& output_0 = output_info_.second["OUTPUT0"]; + EXPECT_EQ(std::get<0>(output_0), TRITONSERVER_MEMORY_GPU); + EXPECT_EQ(std::get<1>(output_0), int64_t(2)); + EXPECT_EQ(std::get<2>(output_0), size_t(16)); + + const auto& output_1 = output_info_.second["OUTPUT1"]; + EXPECT_EQ(std::get<0>(output_1), TRITONSERVER_MEMORY_GPU); + EXPECT_EQ(std::get<1>(output_1), int64_t(2)); + EXPECT_EQ(std::get<2>(output_1), size_t(16)); +} + +} // namespace + +int +main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/3rdparty/core-r22.12/src/test/register_api_test.cc b/3rdparty/core-r22.12/src/test/register_api_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5a41bbad98dc312ab2f38042299d25433f3f2a05 --- /dev/null +++ b/3rdparty/core-r22.12/src/test/register_api_test.cc @@ -0,0 +1,905 @@ +// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
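+
+// These tests cover the model repository registration C API
+// (TRITONSERVER_ServerRegisterModelRepository and
+// TRITONSERVER_ServerUnregisterModelRepository), including directory-to-name
+// mappings and behavior under the EXPLICIT, POLL and NONE model control modes.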
+ +#include +#include +#include "gtest/gtest.h" +#include "triton/core/tritonserver.h" + +namespace { + +#define FAIL_TEST_IF_ERR(X, MSG) \ + do { \ + std::shared_ptr err__((X), TRITONSERVER_ErrorDelete); \ + ASSERT_TRUE((err__ == nullptr)) \ + << "error: " << (MSG) << ": " \ + << TRITONSERVER_ErrorCodeString(err__.get()) << " - " \ + << TRITONSERVER_ErrorMessage(err__.get()); \ + } while (false) + +#define FAIL_TEST_IF_NOT_ERR(X, CODE, ERR_MSG, MSG) \ + do { \ + std::shared_ptr err__((X), TRITONSERVER_ErrorDelete); \ + ASSERT_TRUE((err__ != nullptr)) << "expected error on: " << (MSG); \ + if (err__ != nullptr) { \ + EXPECT_EQ(TRITONSERVER_ErrorCode(err__.get()), (CODE)) << (MSG); \ + EXPECT_STREQ(TRITONSERVER_ErrorMessage(err__.get()), (ERR_MSG)) \ + << (MSG); \ + } \ + } while (false) + +// Test Fixture, this test suit expects the current directory to +// have the following file structure: +// - empty_models (empty directory) +// - models_0 (contain model directory "model_0") +// - models_1 (contain model directories "model_0", "model_1") +// - models_2 (contain model directories "model_0" with config name +// "mapped_name") +class RegisterApiTest : public ::testing::Test { + protected: + void SetUp() override + { + // Create running server object. + TRITONSERVER_ServerOptions* server_options = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsNew(&server_options), + "creating server options"); + // Triton expects at least one model repository is set at start, set to + // an empty repository set ModelControlMode to EXPLICIT to avoid attempting + // to load models. + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetModelRepositoryPath( + server_options, "empty_models"), + "setting model repository path"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetModelControlMode( + server_options, TRITONSERVER_MODEL_CONTROL_EXPLICIT), + "setting model control mode"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerNew(&server_, server_options), "creating server"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsDelete(server_options), + "deleting server options"); + ASSERT_TRUE(server_ != nullptr) << "server not created"; + bool live = false; + for (int i = 10; ((i > 0) && !live); --i) { + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerIsLive(server_, &live), "Is server live"); + } + ASSERT_TRUE(live) << "server not live"; + } + + void TearDown() override + { + FAIL_TEST_IF_ERR(TRITONSERVER_ServerDelete(server_), "deleting server"); + } + + TRITONSERVER_Server* server_ = nullptr; +}; + +TEST_F(RegisterApiTest, Register) +{ + // Request to load "model_0" which should fail + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_0"), + TRITONSERVER_ERROR_INTERNAL, + "failed to load 'model_0', failed to poll from model repository", + "loading model 'model_0'"); + + // Registering a repository "models_0" where contains "model_0" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", nullptr, 0), + "registering model repository 'models_0'"); + // Request to load "model_0" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_0"), + "loading model 'model_0'"); + bool ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "model_0", 1, &ready), + "Is 'model_0' v1 ready"); + ASSERT_TRUE(ready) << "Expect 'model_0' v1 to be ready, model directory is " + "'models_0/model_0'"; +} + +TEST_F(RegisterApiTest, RegisterWithMap) +{ + // Registering a repository "models_0" where contains "model_0", but with + // different name mapping + 
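+  // A mapping entry is a TRITONSERVER_Parameter whose key is the model's
+  // directory name in the repository ("model_0") and whose string value is
+  // the name the model is served under ("name_0").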
const char* override_name = "name_0"; + std::shared_ptr managed_param( + TRITONSERVER_ParameterNew( + "model_0", TRITONSERVER_PARAMETER_STRING, override_name), + TRITONSERVER_ParameterDelete); + ASSERT_TRUE(managed_param != nullptr) << "failed to create name mapping pair"; + std::vector name_map{managed_param.get()}; + + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", name_map.data(), name_map.size()), + "registering model repository 'models_0'"); + + // Request to load "model_0" which should fail + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_0"), + TRITONSERVER_ERROR_INTERNAL, + "failed to load 'model_0', failed to poll from model repository", + "loading model 'model_0'"); + // Request to load "name_0" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "name_0"), + "loading model 'name_0'"); + bool ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "name_0", 1, &ready), + "Is 'name_0' v1 ready"); + ASSERT_TRUE(ready) << "Expect 'name_0' v1 to be ready, model directory is " + "'models_0/model_0'"; +} + +TEST_F(RegisterApiTest, RegisterTwice) +{ + // Registering a startup repository + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "empty_models", nullptr, 0), + TRITONSERVER_ERROR_ALREADY_EXISTS, + "model repository 'empty_models' has already been registered", + "registering model repository 'empty_models'"); +} + +TEST_F(RegisterApiTest, RegisterTwice2) +{ + // Registering the same repository twice + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", nullptr, 0), + "registering model repository 'models_0'"); + + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", nullptr, 0), + TRITONSERVER_ERROR_ALREADY_EXISTS, + "model repository 'models_0' has already been registered", + "registering model repository 'models_0'"); +} + +TEST_F(RegisterApiTest, RegisterWithMultiMap) +{ + // Registering a repository "models_0" where contains "model_0", + // and "model_0" is mapped to two different names + std::vector override_names{"name_0", "name_1"}; + std::vector> managed_params; + std::vector name_map; + for (const auto& name : override_names) { + managed_params.emplace_back( + TRITONSERVER_ParameterNew( + "model_0", TRITONSERVER_PARAMETER_STRING, name.c_str()), + TRITONSERVER_ParameterDelete); + ASSERT_TRUE(managed_params.back() != nullptr) + << "failed to create name mapping pair"; + name_map.emplace_back(managed_params.back().get()); + } + + // Such mapping should be allow as it is mapping to unique names + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", name_map.data(), name_map.size()), + "registering model repository 'models_0'"); + + // Request to load "name_0" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "name_0"), + "loading model 'name_0'"); + bool ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "name_0", 1, &ready), + "Is 'name_0' v1 ready"); + ASSERT_TRUE(ready) << "Expect 'name_0' v1 to be ready, model directory is " + "'models_0/model_0'"; + + // Request to load "name_1" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "name_1"), + "loading model 'name_1'"); + ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "name_1", 1, &ready), + "Is 'name_1' v1 ready"); + ASSERT_TRUE(ready) << "Expect 'name_1' v1 to be ready, model directory is " + "'models_0/model_0'"; +} + 
+TEST_F(RegisterApiTest, RegisterWithRepeatedMap)
+{
+  // Register repository "models_1", which contains "model_0" and "model_1",
+  // and map "model_0" to "model_1", which creates a conflict. However, in
+  // EXPLICIT mode the mapping lookup has higher priority than repository
+  // polling, so the conflict is resolved by always loading the model from
+  // the mapped directory.
+  std::vector<std::string> override_names{"model_1"};
+  std::vector<std::shared_ptr<TRITONSERVER_Parameter>> managed_params;
+  std::vector<const TRITONSERVER_Parameter*> name_map;
+  managed_params.emplace_back(
+      TRITONSERVER_ParameterNew(
+          "model_0", TRITONSERVER_PARAMETER_STRING, override_names[0].c_str()),
+      TRITONSERVER_ParameterDelete);
+  ASSERT_TRUE(managed_params.back() != nullptr)
+      << "failed to create name mapping pair";
+  name_map.emplace_back(managed_params.back().get());
+
+  // Such a mapping should be allowed since the names within the mapping are
+  // unique
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerRegisterModelRepository(
+          server_, "models_1", name_map.data(), name_map.size()),
+      "registering model repository 'models_1'");
+
+  // Request to load "model_1"
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerLoadModel(server_, "model_1"),
+      "loading model 'model_1'");
+  bool ready = false;
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerModelIsReady(server_, "model_1", 2, &ready),
+      "Is 'model_1' ready");
+  ASSERT_TRUE(ready) << "Expect 'model_1' v2 to be ready, model directory is "
+                        "'models_1/model_0'";
+}
+
+TEST_F(RegisterApiTest, RegisterWithRepeatedMap2)
+{
+  // Register repository "models_1", which contains "model_0" and "model_1",
+  // and map both directories to the same name, which creates a conflict.
+  // Unlike 'RegisterWithRepeatedMap', a conflict within the mapping itself
+  // can't be resolved, so an error should be returned.
+  std::vector<std::string> dir_names{"model_0", "model_1"};
+  std::vector<std::shared_ptr<TRITONSERVER_Parameter>> managed_params;
+  std::vector<const TRITONSERVER_Parameter*> name_map;
+  for (const auto& name : dir_names) {
+    managed_params.emplace_back(
+        TRITONSERVER_ParameterNew(
+            name.c_str(), TRITONSERVER_PARAMETER_STRING, "name_0"),
+        TRITONSERVER_ParameterDelete);
+    ASSERT_TRUE(managed_params.back() != nullptr)
+        << "failed to create name mapping pair";
+    name_map.emplace_back(managed_params.back().get());
+  }
+
+  // Register should fail
+  FAIL_TEST_IF_NOT_ERR(
+      TRITONSERVER_ServerRegisterModelRepository(
+          server_, "models_1", name_map.data(), name_map.size()),
+      TRITONSERVER_ERROR_INVALID_ARG,
+      "failed to register 'models_1', there is a conflicting mapping for "
+      "'name_0'",
+      "registering model repository 'models_1'");
+}
+
+TEST_F(RegisterApiTest, RegisterMulti)
+{
+  // Register repositories "models_0" and "models_1" without mappings; they
+  // contain duplicate models, but this won't be checked until load time.
+  std::vector<const TRITONSERVER_Parameter*> name_map;
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerRegisterModelRepository(
+          server_, "models_0", name_map.data(), name_map.size()),
+      "registering model repository 'models_0'");
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerRegisterModelRepository(
+          server_, "models_1", name_map.data(), name_map.size()),
+      "registering model repository 'models_1'");
+
+  // Request to load "model_0", which should fail because the directory
+  // exists in both repositories
+  FAIL_TEST_IF_NOT_ERR(
+      TRITONSERVER_ServerLoadModel(server_, "model_0"),
+      TRITONSERVER_ERROR_INTERNAL,
+      "failed to load 'model_0', failed to poll from model repository",
+      "loading model 'model_0'");
+  // Request to load "model_1"
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerLoadModel(server_, "model_1"),
+      "loading model 'model_1'");
+  bool ready = false;
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerModelIsReady(server_, "model_1", 3, &ready),
+      "Is 'model_1' ready");
+  ASSERT_TRUE(ready) << "Expect 'model_1' v3 to be ready, model directory is "
+                        "'models_1/model_1'";
+}
+
+TEST_F(RegisterApiTest, RegisterMultiWithMap)
+{
+  // Register repositories "models_0" and "models_1"; they contain duplicate
+  // models, but we provide an "override" map for "models_0" that maps
+  // "model_0" to "model_0", which gives that repository priority and
+  // resolves the conflict.
+  std::vector<std::string> override_names{"model_0"};
+  std::vector<std::shared_ptr<TRITONSERVER_Parameter>> managed_params;
+  std::vector<const TRITONSERVER_Parameter*> name_map;
+  managed_params.emplace_back(
+      TRITONSERVER_ParameterNew(
+          "model_0", TRITONSERVER_PARAMETER_STRING, override_names[0].c_str()),
+      TRITONSERVER_ParameterDelete);
+  ASSERT_TRUE(managed_params.back() != nullptr)
+      << "failed to create name mapping pair";
+  name_map.emplace_back(managed_params.back().get());
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerRegisterModelRepository(
+          server_, "models_0", name_map.data(), name_map.size()),
+      "registering model repository 'models_0'");
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerRegisterModelRepository(
+          server_, "models_1", nullptr, 0),
+      "registering model repository 'models_1'");
+
+  // Request to load "model_0", "model_1"
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerLoadModel(server_, "model_0"),
+      "loading model 'model_0'");
+  bool ready = false;
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerModelIsReady(server_, "model_0", 1, &ready),
+      "Is 'model_0' ready");
+  ASSERT_TRUE(ready) << "Expect 'model_0' v1 to be ready, model directory is "
+                        "'models_0/model_0'";
+
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerLoadModel(server_, "model_1"),
+      "loading model 'model_1'");
+  ready = false;
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_ServerModelIsReady(server_, "model_1", 3, &ready),
+      "Is 'model_1' ready");
+  ASSERT_TRUE(ready) << "Expect 'model_1' v3 to be ready, model directory is "
+                        "'models_1/model_1'";
+}
+
+TEST_F(RegisterApiTest, RegisterMultiWithMap2)
+{
+  // Register repositories "models_0" and "models_1"; they contain duplicate
+  // models, but we provide a map for "models_1" so that every model ends up
+  // with a different name.
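+  // Only "models_1" gets a mapping: its "model_0" directory is exposed as
+  // "model_2", so "model_0" resolves to 'models_0/model_0', "model_1" to
+  // 'models_1/model_1', and "model_2" to 'models_1/model_0'.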
+ std::vector override_names{"model_2"}; + std::vector> managed_params; + std::vector name_map; + managed_params.emplace_back( + TRITONSERVER_ParameterNew( + "model_0", TRITONSERVER_PARAMETER_STRING, override_names[0].c_str()), + TRITONSERVER_ParameterDelete); + ASSERT_TRUE(managed_params.back() != nullptr) + << "failed to create name mapping pair"; + name_map.emplace_back(managed_params.back().get()); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", nullptr, 0), + "registering model repository 'models_0'"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_1", name_map.data(), name_map.size()), + "registering model repository 'models_1'"); + + // Request to load "model_0", "model_1", "model_2" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_0"), + "loading model 'model_0'"); + bool ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "model_0", 1, &ready), + "Is 'model_0' ready"); + ASSERT_TRUE(ready) << "Expect 'model_0' v1 to be ready, model directory is " + "'models_0/model_0'"; + + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_1"), + "loading model 'model_1'"); + ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "model_1", 3, &ready), + "Is 'model_1' ready"); + ASSERT_TRUE(ready) << "Expect 'model_1' v3 to be ready, model directory is " + "'models_1/model_1'"; + + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_2"), + "loading model 'model_2'"); + ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "model_2", 2, &ready), + "Is 'model_2' ready"); + ASSERT_TRUE(ready) << "Expect 'model_2' v2 to be ready, model directory is " + "'models_1/model_0'"; +} + +TEST_F(RegisterApiTest, RegisterMultiWithMap3) +{ + // Registering repository "models_0" and "model_1s", + // there are duplicate models but we provides a map for both + // "models_0" and "models_1" so they all have different name. 
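+  // Here "models_0/model_0" is mapped to "name_0" and "models_1/model_0" to
+  // "name_1", while "model_1" keeps its directory name, so all three models
+  // can be loaded without conflict.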
+ std::vector override_names{"name_0", "name_1"}; + std::vector> managed_params; + for (const auto& name : override_names) { + managed_params.emplace_back( + TRITONSERVER_ParameterNew( + "model_0", TRITONSERVER_PARAMETER_STRING, name.c_str()), + TRITONSERVER_ParameterDelete); + ASSERT_TRUE(managed_params.back() != nullptr) + << "failed to create name mapping pair"; + } + std::vector models_0_map{ + managed_params[0].get()}; + std::vector models_1_map{ + managed_params[1].get()}; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", models_0_map.data(), models_0_map.size()), + "registering model repository 'models_0'"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_1", models_1_map.data(), models_1_map.size()), + "registering model repository 'models_1'"); + + // Request to load "model_0", "model_1", "model_2" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "name_0"), + "loading model 'name_0'"); + bool ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "name_0", 1, &ready), + "Is 'name_0' ready"); + ASSERT_TRUE(ready) << "Expect 'name_0' v1 to be ready, model directory is " + "'models_0/model_0'"; + + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "name_1"), + "loading model 'name_1'"); + ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "name_1", 2, &ready), + "Is 'name_1' ready"); + ASSERT_TRUE(ready) << "Expect 'name_1' v2 to be ready, model directory is " + "'models_1/model_0'"; + + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_1"), + "loading model 'model_1'"); + ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "model_1", 3, &ready), + "Is 'model_1' ready"); + ASSERT_TRUE(ready) << "Expect 'model_1' v3 to be ready, model directory is " + "'models_1/model_1'"; +} + +TEST_F(RegisterApiTest, RegisterNonExistingRepo) +{ + // Register should fail + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "unknown_repo", nullptr, 0), + TRITONSERVER_ERROR_INVALID_ARG, + "failed to register 'unknown_repo', repository not found", + "registering model repository 'unknown_repo'"); +} + + +TEST_F(RegisterApiTest, UnregisterInvalidRepo) +{ + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerUnregisterModelRepository(server_, "unknown_repo"), + TRITONSERVER_ERROR_INVALID_ARG, + "failed to unregister 'unknown_repo', repository not found", + "unregistering model repository 'unknown_repo'"); +} + +TEST_F(RegisterApiTest, Unregister) +{ + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerUnregisterModelRepository(server_, "empty_models"), + "unregistering model repository 'empty_models'"); +} + +TEST_F(RegisterApiTest, UnregisterTwice) +{ + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerUnregisterModelRepository(server_, "empty_models"), + "unregistering model repository 'empty_models'"); + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerUnregisterModelRepository(server_, "empty_models"), + TRITONSERVER_ERROR_INVALID_ARG, + "failed to unregister 'empty_models', repository not found", + "unregistering model repository 'empty_models'"); +} + +TEST_F(RegisterApiTest, UnregisterWithLoadedModel) +{ + // Registering a repository "models_0" where contains "model_0" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", nullptr, 0), + "registering model repository 'models_0'"); + // Request to load "model_0" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_0"), + "loading 
model 'model_0'"); + + // Unregister and the model should still be loaded + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerUnregisterModelRepository(server_, "models_0"), + "unregistering model repository 'models_0'"); + + bool ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "model_0", 1, &ready), + "Is 'model_0' ready"); + ASSERT_TRUE(ready) << "Expect 'model_0' v1 to be ready, model directory is " + "'models_0/model_0'"; + + // Request to load "model_0" which should fail + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_0"), + TRITONSERVER_ERROR_INTERNAL, + "failed to load 'model_0', failed to poll from model repository", + "loading model 'model_0'"); +} + +TEST_F(RegisterApiTest, MultiRegister) +{ + // Register / unregister a repository "models_0" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", nullptr, 0), + "registering model repository 'models_0'"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerUnregisterModelRepository(server_, "models_0"), + "unregistering model repository 'models_0'"); + // Register / unregister "models_0" again + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", nullptr, 0), + "registering model repository 'models_0'"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerUnregisterModelRepository(server_, "models_0"), + "unregistering model repository 'models_0'"); +} + +TEST_F(RegisterApiTest, RegisterMulti2) +{ + // Registering repository "models_0" and "model_1" without mappings, + // there are duplicate models but it won't be checked until load + std::vector name_map; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", name_map.data(), name_map.size()), + "registering model repository 'models_0'"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_1", name_map.data(), name_map.size()), + "registering model repository 'models_1'"); + + // Request to load "model_0" which should fail + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_0"), + TRITONSERVER_ERROR_INTERNAL, + "failed to load 'model_0', failed to poll from model repository", + "loading model 'model_0'"); + // Request to load "model_1" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_1"), + "loading model 'model_1'"); + + // Unregister one of the repos and 'model_0' can be loaded as there is no + // confliction + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerUnregisterModelRepository(server_, "models_1"), + "unregistering model repository 'models_1'"); + // Request to load "model_0" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_0"), + "loading model 'model_0'"); + + bool ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "model_0", 1, &ready), + "Is 'model_0' ready"); + ASSERT_TRUE(ready) << "Expect 'model_0' v1 to be ready, model directory is " + "'models_0/model_0'"; + + ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "model_1", 3, &ready), + "Is 'model_1' ready"); + ASSERT_TRUE(ready) << "Expect 'model_1' v3 to be ready, model directory is " + "'models_1/model_1'"; +} + +TEST_F(RegisterApiTest, DifferentMapping) +{ + // With register and unregister, user can update a mapping for specific repo. 
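+  // Flow: register "models_0" without a mapping and load "model_0", then
+  // unregister and re-register it with a "model_0" -> "name_0" mapping.
+  // Afterwards, loads must use "name_0", while the previously loaded
+  // "model_0" remains available in memory.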
+ std::vector override_names{"name_0"}; + std::vector> managed_params; + std::vector name_map; + managed_params.emplace_back( + TRITONSERVER_ParameterNew( + "model_0", TRITONSERVER_PARAMETER_STRING, override_names[0].c_str()), + TRITONSERVER_ParameterDelete); + ASSERT_TRUE(managed_params.back() != nullptr) + << "failed to create name mapping pair"; + name_map.emplace_back(managed_params.back().get()); + + // First register without mapping + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", nullptr, 0), + "registering model repository 'models_0'"); + // Request to load "model_0" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_0"), + "loading model 'model_0'"); + + // Re-register with mapping + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerUnregisterModelRepository(server_, "models_0"), + "unregistering model repository 'models_0'"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", name_map.data(), name_map.size()), + "registering model repository 'models_0'"); + // Request to load "model_0" will fail, but load "name_0" is okay + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_0"), + TRITONSERVER_ERROR_INTERNAL, + "failed to load 'model_0', failed to poll from model repository", + "loading model 'model_0'"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "name_0"), + "loading model 'name_0'"); + + bool ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "name_0", 1, &ready), + "Is 'name_0' ready"); + ASSERT_TRUE(ready) << "Expect 'name_0' v1 to be ready, model directory is " + "'models_0/model_0'"; + + // Verify that model_0 still exists in-memory + ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "model_0", 1, &ready), + "Is 'model_0' ready"); + ASSERT_TRUE(ready) << "Expect 'model_0' v1 to be ready, model directory is " + "'models_0/model_0'"; +} + +TEST_F(RegisterApiTest, CorrectIndex) +{ + // Registering a repository "models_0" where contains "model_0", but with + // different name mapping + const char* override_name = "name_0"; + std::shared_ptr managed_param( + TRITONSERVER_ParameterNew( + "model_0", TRITONSERVER_PARAMETER_STRING, override_name), + TRITONSERVER_ParameterDelete); + ASSERT_TRUE(managed_param != nullptr) << "failed to create name mapping pair"; + std::vector name_map{managed_param.get()}; + + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", name_map.data(), name_map.size()), + "registering model repository 'models_0'"); + + // Request to load "model_0" which should fail + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerLoadModel(server_, "model_0"), + TRITONSERVER_ERROR_INTERNAL, + "failed to load 'model_0', failed to poll from model repository", + "loading model 'model_0'"); + // Request to load "name_0" + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerLoadModel(server_, "name_0"), + "loading model 'name_0'"); + bool ready = false; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIsReady(server_, "name_0", 1, &ready), + "Is 'name_0' v1 ready"); + ASSERT_TRUE(ready) << "Expect 'name_0' v1 to be ready, model directory is " + "'models_0/model_0'"; + + TRITONSERVER_Message* repository_index; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIndex(server_, 1, &repository_index), + "checking model indexes"); + const char* base = nullptr; + size_t byte_size = 0; + FAIL_TEST_IF_ERR( + TRITONSERVER_MessageSerializeToJson(repository_index, &base, &byte_size), + "serializing index to 
Json"); + const std::string search_msg = + "[{\"name\":\"name_0\",\"version\":\"1\",\"state\":\"READY\"}]"; + const std::string serialized_index(base, byte_size); + EXPECT_EQ(serialized_index, search_msg) + << "Returned index does not equal expected index"; +} + +TEST_F(RegisterApiTest, CorrectIndexNotLoaded) +{ + // Registering a repository "models_0" where contains "model_0", but with + // different name mapping + const char* override_name = "name_0"; + std::shared_ptr managed_param( + TRITONSERVER_ParameterNew( + "model_0", TRITONSERVER_PARAMETER_STRING, override_name), + TRITONSERVER_ParameterDelete); + ASSERT_TRUE(managed_param != nullptr) << "failed to create name mapping pair"; + std::vector name_map{managed_param.get()}; + + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "models_0", name_map.data(), name_map.size()), + "registering model repository 'models_0'"); + + TRITONSERVER_Message* repository_index; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerModelIndex(server_, 0, &repository_index), + "checking model indexes"); + const char* base = nullptr; + size_t byte_size = 0; + FAIL_TEST_IF_ERR( + TRITONSERVER_MessageSerializeToJson(repository_index, &base, &byte_size), + "serializing index to Json"); + const std::string search_msg = "[{\"name\":\"name_0\"}]"; + const std::string serialized_index(base, byte_size); + EXPECT_EQ(serialized_index, search_msg) + << "Returned index does not equal expected index"; +} + +// // Test Fixture that runs server with POLLING mode +class PollingRegisterApiTest : public ::testing::Test { + protected: + void SetUp() override + { + // Create running server object. + TRITONSERVER_ServerOptions* server_options = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsNew(&server_options), + "creating server options"); + // Triton expects at least one model repository is set at start, set to + // an empty repository set ModelControlMode to EXPLICIT to avoid attempting + // to load models. 
+ FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetModelRepositoryPath( + server_options, "empty_models"), + "setting model repository path"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetModelControlMode( + server_options, TRITONSERVER_MODEL_CONTROL_POLL), + "setting model control mode"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerNew(&server_, server_options), "creating server"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsDelete(server_options), + "deleting server options"); + ASSERT_TRUE(server_ != nullptr) << "server not created"; + bool live = false; + for (int i = 10; ((i > 0) && !live); --i) { + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerIsLive(server_, &live), "Is server live"); + } + ASSERT_TRUE(live) << "server not live"; + } + + void TearDown() override + { + FAIL_TEST_IF_ERR(TRITONSERVER_ServerDelete(server_), "deleting server"); + } + + TRITONSERVER_Server* server_ = nullptr; +}; + +TEST_F(PollingRegisterApiTest, unsupport) +{ + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "empty_models", nullptr, 0), + TRITONSERVER_ERROR_UNSUPPORTED, + "repository registration is not allowed if model control mode is not " + "EXPLICIT", + "registering model repository 'empty_models'"); + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerUnregisterModelRepository(server_, "empty_models"), + TRITONSERVER_ERROR_UNSUPPORTED, + "repository unregistration is not allowed if model control mode is not " + "EXPLICIT", + "unregistering model repository 'empty_models'"); +} + +// Test Fixture that runs server with NONE mode +class NoneRegisterApiTest : public ::testing::Test { + protected: + void SetUp() override + { + // Create running server object. + TRITONSERVER_ServerOptions* server_options = nullptr; + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsNew(&server_options), + "creating server options"); + // Triton expects at least one model repository is set at start, set to + // an empty repository set ModelControlMode to EXPLICIT to avoid attempting + // to load models. 
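+    // (As in the POLL fixture above, the mode set below is not EXPLICIT; this
+    // fixture uses NONE, and the test expects repository registration and
+    // unregistration to be rejected in that mode.)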
+ FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetModelRepositoryPath( + server_options, "empty_models"), + "setting model repository path"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsSetModelControlMode( + server_options, TRITONSERVER_MODEL_CONTROL_NONE), + "setting model control mode"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerNew(&server_, server_options), "creating server"); + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerOptionsDelete(server_options), + "deleting server options"); + ASSERT_TRUE(server_ != nullptr) << "server not created"; + bool live = false; + for (int i = 10; ((i > 0) && !live); --i) { + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerIsLive(server_, &live), "Is server live"); + } + ASSERT_TRUE(live) << "server not live"; + } + + void TearDown() override + { + FAIL_TEST_IF_ERR(TRITONSERVER_ServerDelete(server_), "deleting server"); + } + + TRITONSERVER_Server* server_ = nullptr; +}; + +TEST_F(NoneRegisterApiTest, unsupport) +{ + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerRegisterModelRepository( + server_, "empty_models", nullptr, 0), + TRITONSERVER_ERROR_UNSUPPORTED, + "repository registration is not allowed if model control mode is not " + "EXPLICIT", + "registering model repository 'empty_models'"); + FAIL_TEST_IF_NOT_ERR( + TRITONSERVER_ServerUnregisterModelRepository(server_, "empty_models"), + TRITONSERVER_ERROR_UNSUPPORTED, + "repository unregistration is not allowed if model control mode is not " + "EXPLICIT", + "unregistering model repository 'empty_models'"); +} + +} // namespace + +int +main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/3rdparty/core-r22.12/src/test/repo_agent_test.cc b/3rdparty/core-r22.12/src/test/repo_agent_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..76d72e22db69cb336d4caed4c6f4eb5a94fc53dc --- /dev/null +++ b/3rdparty/core-r22.12/src/test/repo_agent_test.cc @@ -0,0 +1,2365 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
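+
+// Unit tests for the Triton repository-agent support (TritonRepoAgent,
+// TritonRepoAgentManager, TritonRepoAgentModel and the TRITONREPOAGENT_* C
+// API). The TRITONSERVER_Error / TRITONSERVER_Message implementations and the
+// SharedLibrary loader are stubbed out below so that agents can be supplied
+// as in-process mocks (registered in global_mock_agents) rather than loaded
+// from real shared libraries.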
+#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include +#include +#include "filesystem.h" +#include "repo_agent.h" +#include "server_message.h" +#include "shared_library.h" + +namespace tc = triton::core; + +namespace { + +// +// Duplication of TRITONSERVER_Error implementation +// +class TritonServerError { + public: + static TRITONSERVER_Error* Create( + TRITONSERVER_Error_Code code, const char* msg); + static TRITONSERVER_Error* Create(const tc::Status& status); + + TRITONSERVER_Error_Code Code() const { return code_; } + const std::string& Message() const { return msg_; } + + private: + TritonServerError(TRITONSERVER_Error_Code code, const std::string& msg) + : code_(code), msg_(msg) + { + } + TritonServerError(TRITONSERVER_Error_Code code, const char* msg) + : code_(code), msg_(msg) + { + } + + TRITONSERVER_Error_Code code_; + const std::string msg_; +}; + +TRITONSERVER_Error* +TritonServerError::Create(TRITONSERVER_Error_Code code, const char* msg) +{ + return reinterpret_cast( + new TritonServerError(code, msg)); +} + +TRITONSERVER_Error* +TritonServerError::Create(const tc::Status& status) +{ + // If 'status' is success then return nullptr as that indicates + // success + if (status.IsOk()) { + return nullptr; + } + + return Create( + tc::StatusCodeToTritonCode(status.StatusCode()), + status.Message().c_str()); +} + +class MockSharedLibraryHandle { + public: + bool AddEntryPoint(const std::string& name, void* fn) + { + auto it = entry_points_.find(name); + if (it == entry_points_.end()) { + entry_points_.emplace(name, fn).second; + return true; + } else { + it->second = fn; + return false; + } + } + + bool GetEntryPoint(const std::string& name, void** fn) + { + auto it = entry_points_.find(name); + if (it != entry_points_.end()) { + *fn = it->second; + return true; + } + return false; + } + + private: + std::map entry_points_; +}; + +static std::map global_mock_agents; + +} // namespace + +#ifdef __cplusplus +extern "C" { +#endif + +TRITONSERVER_Error* +TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code code, const char* msg) +{ + return reinterpret_cast( + TritonServerError::Create(code, msg)); +} + +void +TRITONSERVER_ErrorDelete(TRITONSERVER_Error* error) +{ + TritonServerError* lerror = reinterpret_cast(error); + delete lerror; +} + +TRITONSERVER_Error_Code +TRITONSERVER_ErrorCode(TRITONSERVER_Error* error) +{ + TritonServerError* lerror = reinterpret_cast(error); + return lerror->Code(); +} + +const char* +TRITONSERVER_ErrorCodeString(TRITONSERVER_Error* error) +{ + TritonServerError* lerror = reinterpret_cast(error); + return tc::Status::CodeString(tc::TritonCodeToStatusCode(lerror->Code())); +} + +const char* +TRITONSERVER_ErrorMessage(TRITONSERVER_Error* error) +{ + TritonServerError* lerror = reinterpret_cast(error); + return lerror->Message().c_str(); +} + +// +// TRITONSERVER_Message +// +TRITONSERVER_Error* +TRITONSERVER_MessageNewFromSerializedJson( + TRITONSERVER_Message** message, const char* base, size_t byte_size) +{ + *message = reinterpret_cast( + new tc::TritonServerMessage({base, byte_size})); + return nullptr; +} + +TRITONSERVER_Error* +TRITONSERVER_MessageSerializeToJson( + TRITONSERVER_Message* message, const char** base, size_t* byte_size) +{ + tc::TritonServerMessage* lmessage = + reinterpret_cast(message); + lmessage->Serialize(base, byte_size); + return nullptr; // Success +} + +#ifdef __cplusplus +} +#endif + +namespace triton { namespace core { + +Status +SharedLibrary::Acquire(std::unique_ptr* slib) +{ + slib->reset(new 
SharedLibrary()); + return Status::Success; +} + +SharedLibrary::~SharedLibrary() {} +Status +SharedLibrary::SetLibraryDirectory(const std::string& path) +{ + return Status::Success; +} +Status +SharedLibrary::ResetLibraryDirectory() +{ + return Status::Success; +} +Status +SharedLibrary::OpenLibraryHandle(const std::string& path, void** handle) +{ + auto it = global_mock_agents.find(path); + if (it != global_mock_agents.end()) { + *handle = reinterpret_cast(&it->second); + return Status::Success; + } + return Status( + Status::Code::NOT_FOUND, + "unable to load shared library: mock shared library is not set for " + "path " + + path); +} + +Status +SharedLibrary::CloseLibraryHandle(void* handle) +{ + for (auto& global_mock_agent : global_mock_agents) { + if (reinterpret_cast(&global_mock_agent.second) == handle) { + return Status::Success; + } + } + return Status( + Status::Code::NOT_FOUND, + "unable to unload shared library: handle does not matach any mock shared " + "library"); +} + +Status +SharedLibrary::GetEntrypoint( + void* handle, const std::string& name, const bool optional, void** fn) +{ + auto mock_agent = reinterpret_cast(handle); + bool found = mock_agent->GetEntryPoint(name, fn); + if (!optional && !found) { + return Status( + Status::Code::NOT_FOUND, + "unable to find required entrypoint '" + name + "' in shared library"); + } + return Status::Success; +} + +}} // namespace triton::core + +namespace { + +class TritonRepoAgentTest : public ::testing::Test { + protected: + void TearDown() override { global_mock_agents.clear(); } +}; + +TEST_F(TritonRepoAgentTest, Create) +{ + // Set up agent with only action function defined, check agent properties + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t CheckNameModelActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) -> TRITONSERVER_Error* { + auto lagent = reinterpret_cast(agent); + EXPECT_EQ(lagent->Name(), "minimal_agent") + << "Expect action function is called with minimal agent"; + return nullptr; + }; + auto agent_handle = MockSharedLibraryHandle(); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelAction", + reinterpret_cast(CheckNameModelActionFn)); + global_mock_agents.emplace("minimal_agent_path", agent_handle); + + std::shared_ptr minimal_agent; + auto status = tc::TritonRepoAgent::Create( + "minimal_agent", "minimal_agent_path", &minimal_agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + ASSERT_TRUE(minimal_agent->AgentModelActionFn() != nullptr) + << "Expect action function is provided"; + EXPECT_TRUE(minimal_agent->AgentModelInitFn() == nullptr) + << "Unexpect model init function is provided"; + EXPECT_TRUE(minimal_agent->AgentModelFiniFn() == nullptr) + << "Unexpect model fini function is provided"; + + auto err = minimal_agent->AgentModelActionFn()( + reinterpret_cast(minimal_agent.get()), nullptr, + TRITONREPOAGENT_ACTION_LOAD); + EXPECT_TRUE(err == nullptr) << "Expect successful action function invocation"; +} + +TEST_F(TritonRepoAgentTest, CreateFailInvalidSharedLibrary) +{ + // Passing a agent path that is not in global_mock_agents to + // simulate failure on opening shared library handle + std::shared_ptr invalid_agent; + auto status = tc::TritonRepoAgent::Create( + "invalid_agent", "invalid_agent_path", &invalid_agent); + ASSERT_FALSE(status.IsOk()) << "Unexpect successful agent creation"; + EXPECT_NE( + status.Message().find("unable to load shared library"), 
std::string::npos) + << "Unexpect error message: '" << status.Message() + << "', expect 'unable to load shared library...'"; +} + +TEST_F(TritonRepoAgentTest, CreateFailMissingEndpoint) +{ + // Set up agent with nothing defined + auto agent_handle = MockSharedLibraryHandle(); + global_mock_agents.emplace("invalid_agent_path", agent_handle); + + std::shared_ptr invalid_agent; + auto status = tc::TritonRepoAgent::Create( + "invalid_agent", "invalid_agent_path", &invalid_agent); + ASSERT_FALSE(status.IsOk()) << "Unexpect successful agent creation"; + EXPECT_NE( + status.Message().find("unable to find required entrypoint"), + std::string::npos) + << "Unexpect error message: '" << status.Message() + << "', expect 'unable to find required entrypoint...'"; +} + +TEST_F(TritonRepoAgentTest, Lifecycle) +{ + // Set up agent with init / fini function defined + tc::TritonRepoAgent::TritonRepoAgentInitFn_t InitFn = + [](TRITONREPOAGENT_Agent* agent) -> TRITONSERVER_Error* { + auto lagent = reinterpret_cast(agent); + EXPECT_TRUE(lagent->State() == nullptr) + << "Expect agent state is not set before initialization"; + bool* state = new bool(false); + lagent->SetState(reinterpret_cast(state)); + return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentFiniFn_t FiniFn = + [](TRITONREPOAGENT_Agent* agent) -> TRITONSERVER_Error* { + auto lagent = reinterpret_cast(agent); + bool* state = reinterpret_cast(lagent->State()); + EXPECT_TRUE(state != nullptr) << "Expect agent state is set"; + EXPECT_TRUE(*state) << "Expect state is set to true"; + delete state; + return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t ActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) -> TRITONSERVER_Error* { + auto lagent = reinterpret_cast(agent); + bool* state = reinterpret_cast(lagent->State()); + EXPECT_TRUE(state != nullptr) << "Expect agent state is set"; + EXPECT_FALSE(*state) << "Expect state is set to false"; + *state = true; + return nullptr; + }; + auto agent_handle = MockSharedLibraryHandle(); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_Initialize", reinterpret_cast(InitFn)); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_Finalize", reinterpret_cast(FiniFn)); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelAction", reinterpret_cast(ActionFn)); + global_mock_agents.emplace("agent_path", agent_handle); + + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + ASSERT_TRUE(agent->AgentModelActionFn() != nullptr) + << "Expect action function is provided"; + EXPECT_TRUE(agent->AgentModelInitFn() == nullptr) + << "Unexpect model init function is provided"; + EXPECT_TRUE(agent->AgentModelFiniFn() == nullptr) + << "Unexpect model fini function is provided"; + + auto err = agent->AgentModelActionFn()( + reinterpret_cast(agent.get()), nullptr, + TRITONREPOAGENT_ACTION_LOAD); + EXPECT_TRUE(err == nullptr) << "Expect successful action function invocation"; + // Cause destructor to be called + agent.reset(); +} + +TEST_F(TritonRepoAgentTest, ModelLifecycle) +{ + // Set up agent with model init / fini function defined + tc::TritonRepoAgent::TritonRepoAgentModelInitFn_t InitFn = + [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model) -> TRITONSERVER_Error* { + auto lmodel_state = + reinterpret_cast*, std::future*>*>( + model); + lmodel_state->first->set_value(); + 
return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentModelFiniFn_t FiniFn = + [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model) -> TRITONSERVER_Error* { + auto lmodel_state = + reinterpret_cast*, std::future*>*>( + model); + lmodel_state->second->get(); + return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t ActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) -> TRITONSERVER_Error* { + auto lmodel_state = + reinterpret_cast*, std::future*>*>( + model); + EXPECT_TRUE(lmodel_state->second->valid()) << "Expect promise value is set"; + return nullptr; + }; + auto agent_handle = MockSharedLibraryHandle(); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelInitialize", reinterpret_cast(InitFn)); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelFinalize", reinterpret_cast(FiniFn)); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelAction", reinterpret_cast(ActionFn)); + global_mock_agents.emplace("agent_path", agent_handle); + + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + ASSERT_TRUE(agent->AgentModelActionFn() != nullptr) + << "Expect action function is provided"; + ASSERT_TRUE(agent->AgentModelInitFn() != nullptr) + << "Expect model init function is provided"; + ASSERT_TRUE(agent->AgentModelFiniFn() != nullptr) + << "Expect model fini function is provided"; + + std::promise p; + auto f = p.get_future(); + auto model_state = std::make_pair(&p, &f); + // Simulate the model lifecycle + auto err = agent->AgentModelInitFn()( + reinterpret_cast(agent.get()), + reinterpret_cast(&model_state)); + EXPECT_TRUE(err == nullptr) + << "Expect successful model init function invocation"; + err = agent->AgentModelActionFn()( + reinterpret_cast(agent.get()), + reinterpret_cast(&model_state), + TRITONREPOAGENT_ACTION_LOAD); + EXPECT_TRUE(err == nullptr) << "Expect successful action function invocation"; + err = agent->AgentModelFiniFn()( + reinterpret_cast(agent.get()), + reinterpret_cast(&model_state)); + EXPECT_TRUE(err == nullptr) + << "Expect successful model fini function invocation"; + EXPECT_FALSE(f.valid()) << "Expect future value is retrieved"; +} + +class TritonRepoAgentManagerTest : public ::testing::Test { + public: + static size_t agent_init_counter_; + static size_t agent_fini_counter_; + + protected: + void SetUp() override + { + // Set up agent with init / fini function defined + tc::TritonRepoAgent::TritonRepoAgentInitFn_t InitFn = + [](TRITONREPOAGENT_Agent* agent) -> TRITONSERVER_Error* { + agent_init_counter_++; + return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentFiniFn_t FiniFn = + [](TRITONREPOAGENT_Agent* agent) -> TRITONSERVER_Error* { + agent_fini_counter_++; + return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t ActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) + -> TRITONSERVER_Error* { return nullptr; }; + auto agent_handle = MockSharedLibraryHandle(); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_Initialize", reinterpret_cast(InitFn)); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_Finalize", reinterpret_cast(FiniFn)); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelAction", reinterpret_cast(ActionFn)); + + // Reserve valid shared library paths because manager searches the libraries 
+ // via the FileSystem API + const tc::FileSystemType type = tc::FileSystemType::LOCAL; + auto status = tc::MakeTemporaryDirectory(type, &root_agent_path_); + ASSERT_TRUE(status.IsOk()) << "TritonRepoAgentManagerTest set up failed: " + "create temporary directory: " + << status.AsString(); + // FIXME make the following platform independent + global_agent_path_ = tc::JoinPath({root_agent_path_, "global"}); + int err = mkdir( + global_agent_path_.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); + ASSERT_EQ(err, 0) << "TritonRepoAgentManagerTest set up failed: create " + "global agent directory: " + << err; + const std::set agent_names{"global_agent"}; + for (const auto& agent_name : agent_names) { + auto global_path_to_agent = + tc::JoinPath({global_agent_path_, agent_name}); + auto global_agent = tc::JoinPath( + {global_path_to_agent, tc::TritonRepoAgentLibraryName(agent_name)}); + err = mkdir( + global_path_to_agent.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); + ASSERT_EQ(err, 0) << "TritonRepoAgentManagerTest set up failed: create " + "global agent directory: " + << err; + std::ofstream global_agent_file(global_agent); + global_mock_agents.emplace(global_agent, agent_handle); + } + status = + tc::TritonRepoAgentManager::SetGlobalSearchPath(global_agent_path_); + ASSERT_TRUE(status.IsOk()) << "TritonRepoAgentManagerTest set up failed: " + "create temporary directory: " + << status.AsString(); + } + void TearDown() override + { + agent_init_counter_ = 0; + agent_fini_counter_ = 0; + if (!root_agent_path_.empty()) { + // tc::DeleteDirectory(root_agent_path_); + } + global_mock_agents.clear(); + } + + std::string root_agent_path_; + std::string global_agent_path_; + std::string local_agent_path_; +}; +size_t TritonRepoAgentManagerTest::agent_init_counter_ = 0; +size_t TritonRepoAgentManagerTest::agent_fini_counter_ = 0; + +TEST_F(TritonRepoAgentManagerTest, CreateFailureFileNotExist) +{ + // Passing a agent path that is not in global_mock_agents to + // simulate failure on opening shared library handle + std::shared_ptr invalid_agent; + auto status = tc::TritonRepoAgentManager::CreateAgent( + "invalid_agent_name", &invalid_agent); + ASSERT_FALSE(status.IsOk()) << "Unexpect successful agent creation"; + EXPECT_NE(status.Message().find("unable to find"), std::string::npos) + << "Unexpect error message: '" << status.Message() + << "', expect 'unable to find...'"; +} + +TEST_F(TritonRepoAgentManagerTest, CreateGlobalAgent) +{ + std::shared_ptr agent; + auto status = tc::TritonRepoAgentManager::CreateAgent("global_agent", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation" << status.AsString(); + agent.reset(); + EXPECT_EQ(agent_init_counter_, (size_t)1) << "Expect 1 agent initialization"; + EXPECT_EQ(agent_fini_counter_, (size_t)1) << "Expect 1 agent finalization"; +} + +TEST_F(TritonRepoAgentManagerTest, AgentPersistence) +{ + std::shared_ptr agent1; + std::shared_ptr agent2; + auto status = + tc::TritonRepoAgentManager::CreateAgent("global_agent", &agent1); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation" << status.AsString(); + EXPECT_EQ(agent_init_counter_, (size_t)1) << "Expect 1 agent initialization"; + EXPECT_EQ(agent_fini_counter_, (size_t)0) << "Expect 0 agent finalization"; + + status = tc::TritonRepoAgentManager::CreateAgent("global_agent", &agent2); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation" << status.AsString(); + EXPECT_EQ(agent_init_counter_, (size_t)1) << "Expect 1 agent initialization"; + 
EXPECT_EQ(agent_fini_counter_, (size_t)0) << "Expect 0 agent finalization"; + + agent1.reset(); + EXPECT_EQ(agent_init_counter_, (size_t)1) << "Expect 1 agent initialization"; + EXPECT_EQ(agent_fini_counter_, (size_t)0) << "Expect 0 agent finalization"; + agent2.reset(); + EXPECT_EQ(agent_init_counter_, (size_t)1) << "Expect 1 agent initialization"; + EXPECT_EQ(agent_fini_counter_, (size_t)1) << "Expect 1 agent finalization"; + + // Create again after all previous agents are reset + status = tc::TritonRepoAgentManager::CreateAgent("global_agent", &agent1); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation" << status.AsString(); + EXPECT_EQ(agent_init_counter_, (size_t)2) << "Expect 2 agent initialization"; + EXPECT_EQ(agent_fini_counter_, (size_t)1) << "Expect 1 agent finalization"; + agent1.reset(); + EXPECT_EQ(agent_init_counter_, (size_t)2) << "Expect 2 agent initialization"; + EXPECT_EQ(agent_fini_counter_, (size_t)2) << "Expect 2 agent finalization"; +} + +class TritonRepoAgentModelTest : public ::testing::Test { + protected: + void SetUp() override + { + simple_config_.set_name("simple_config"); + + // Add a simple agent handle for convinence + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t ActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) + -> TRITONSERVER_Error* { return nullptr; }; + auto agent_handle = MockSharedLibraryHandle(); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelAction", reinterpret_cast(ActionFn)); + global_mock_agents.emplace("simple_agent_path", agent_handle); + + // Add a agent handle for logging actions of the model + tc::TritonRepoAgent::TritonRepoAgentModelInitFn_t LogInitFn = + [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model) -> TRITONSERVER_Error* { + auto lagent = reinterpret_cast(agent); + auto state = reinterpret_cast*>(lagent->State()); + if (state == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "Agent state is not set"); + } + state->emplace_back("Model Initialized"); + return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentModelFiniFn_t LogFiniFn = + [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model) -> TRITONSERVER_Error* { + auto lagent = reinterpret_cast(agent); + auto state = reinterpret_cast*>(lagent->State()); + if (state == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "Agent state is not set"); + } + state->emplace_back("Model Finalized"); + return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t LogActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) + -> TRITONSERVER_Error* { + auto lagent = reinterpret_cast(agent); + auto state = reinterpret_cast*>(lagent->State()); + if (state == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "Agent state is not set"); + } + state->emplace_back(tc::TRITONREPOAGENT_ActionTypeString(action_type)); + return nullptr; + }; + auto log_agent_handle = MockSharedLibraryHandle(); + log_agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelInitialize", reinterpret_cast(LogInitFn)); + log_agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelFinalize", reinterpret_cast(LogFiniFn)); + log_agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelAction", reinterpret_cast(LogActionFn)); + global_mock_agents.emplace("log_agent_path", log_agent_handle); + } + void TearDown() override { 
global_mock_agents.clear(); } + + TRITONREPOAGENT_ArtifactType original_type_ = + TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + const std::string original_location_ = "/original"; + inference::ModelConfig simple_config_; +}; + +TEST_F(TritonRepoAgentModelTest, Create) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + EXPECT_EQ(model->Config().name(), simple_config_.name()) + << "Expect the model contains the same config as simple config"; +} + +TEST_F(TritonRepoAgentModelTest, CreateFailure) +{ + // Create agent to be associated with the model, whose model init function + // always returns error + tc::TritonRepoAgent::TritonRepoAgentModelInitFn_t InitFn = + [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model) -> TRITONSERVER_Error* { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, "Model initialization error"); + }; + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t ActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) -> TRITONSERVER_Error* { + return nullptr; + }; + auto agent_handle = MockSharedLibraryHandle(); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelInitialize", reinterpret_cast(InitFn)); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelAction", reinterpret_cast(ActionFn)); + global_mock_agents.emplace("agent_path", agent_handle); + + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_FALSE(status.IsOk()) << "Unexpect successful model creation"; + EXPECT_NE( + status.Message().find("Model initialization error"), std::string::npos) + << "Unexpect error message: '" << status.Message() + << "', expect 'Model initialization error...'"; +} + +TEST_F(TritonRepoAgentModelTest, Location) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + TRITONREPOAGENT_ArtifactType type; + const char* location; + status = model->Location(&type, &location); + ASSERT_TRUE(status.IsOk()) << "Expect location is returned from Location()"; + EXPECT_EQ(type, original_type_) << "Expect returned original filesystem type"; + EXPECT_EQ(std::string(location), original_location_) + << "Expect returned original location"; +} + +TEST_F(TritonRepoAgentModelTest, SetLocationFailure) +{ + 
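+  // SetLocation() is only permitted while TRITONREPOAGENT_ACTION_LOAD is the
+  // current action; calling it before any agent action has been invoked is
+  // expected to fail.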
// Create agent to be associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + TRITONREPOAGENT_ArtifactType type = TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + const char* location = "/tmp"; + status = model->SetLocation(type, location); + ASSERT_FALSE(status.IsOk()) << "Expect error returned from SetLocation()"; + EXPECT_NE( + status.Message().find( + "location can only be updated during TRITONREPOAGENT_ACTION_LOAD, " + "current action type is not set"), + std::string::npos) + << "Unexpect error message: '" << status.Message() + << "', expect 'location can only be updated during " + "TRITONREPOAGENT_ACTION_LOAD, current action type is not set'"; +} + +TEST_F(TritonRepoAgentModelTest, SetLocation) +{ + static const TRITONREPOAGENT_ArtifactType new_type = + TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + static const std::string new_location = "/new_location"; + + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + // Advance the model lifecycle to be able to set location + status = model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD); + EXPECT_TRUE(status.IsOk()) + << "Expect successful agent invocation with TRITONREPOAGENT_ACTION_LOAD"; + status = model->SetLocation(new_type, new_location); + ASSERT_TRUE(status.IsOk()) + << "Expect successful SetLocation() after invoking agent with " + "TRITONREPOAGENT_ACTION_LOAD"; + TRITONREPOAGENT_ArtifactType type = original_type_; + const char* location = original_location_.c_str(); + status = model->Location(&type, &location); + ASSERT_TRUE(status.IsOk()) << "Expect location is returned from Location()"; + EXPECT_EQ(type, new_type) << "Expect returned filesystem type is " + << tc::TRITONREPOAGENT_ArtifactTypeString(new_type); + EXPECT_EQ(std::string(location), new_location) + << "Expect returned location is " << new_location; +} + +TEST_F(TritonRepoAgentModelTest, SetLocationWrongActionFailure) +{ + static const TRITONREPOAGENT_ArtifactType new_type = + TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + static const std::string new_location = "/new_location"; + + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + // Advance the model lifecycle to be able to set 
location + status = model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD); + EXPECT_TRUE(status.IsOk()) + << "Expect successful agent invocation with TRITONREPOAGENT_ACTION_LOAD"; + status = model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD_COMPLETE); + EXPECT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + "TRITONREPOAGENT_ACTION_LOAD_COMPLETE"; + status = model->SetLocation(new_type, new_location); + ASSERT_FALSE(status.IsOk()) << "Expect error returned from SetLocation()"; + EXPECT_NE( + status.Message().find( + "location can only be updated during TRITONREPOAGENT_ACTION_LOAD, " + "current action type is TRITONREPOAGENT_ACTION_LOAD_COMPLETE"), + std::string::npos) + << "Unexpect error message: '" << status.Message() + << "', expect 'location can only be updated during " + "TRITONREPOAGENT_ACTION_LOAD, current action type is " + "TRITONREPOAGENT_ACTION_LOAD_COMPLETE'"; +} + +TEST_F(TritonRepoAgentModelTest, SetLocationViaAgent) +{ + static const TRITONREPOAGENT_ArtifactType new_type = + TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + static const std::string new_location = "/new_location"; + // Create agent to be associated with the model + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t ActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) -> TRITONSERVER_Error* { + auto lmodel = reinterpret_cast(model); + auto status = lmodel->SetLocation(new_type, new_location); + return reinterpret_cast( + TritonServerError::Create(status)); + }; + auto agent_handle = MockSharedLibraryHandle(); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelAction", reinterpret_cast(ActionFn)); + global_mock_agents.emplace("set_location_agent_path", agent_handle); + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "set_location_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + // Advance the model lifecycle to be able to set location + status = model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent invocation with TRITONREPOAGENT_ACTION_LOAD"; + TRITONREPOAGENT_ArtifactType type = original_type_; + const char* location = original_location_.c_str(); + status = model->Location(&type, &location); + ASSERT_TRUE(status.IsOk()) << "Expect location is returned from Location()"; + EXPECT_EQ(type, new_type) << "Expect returned filesystem type is " + << tc::TRITONREPOAGENT_ArtifactTypeString(new_type); + EXPECT_EQ(std::string(location), new_location) + << "Expect returned location is " << new_location; +} + +TEST_F(TritonRepoAgentModelTest, DeleteLocationBeforeAcquire) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + + 
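+  // No mutable location has been acquired at this point, so the deletion
+  // below is expected to fail.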
status = model->DeleteMutableLocation(); + ASSERT_FALSE(status.IsOk()) + << "Expect error returned from DeleteMutableLocation()"; + EXPECT_NE( + status.Message().find("No mutable location to be deleted"), + std::string::npos) + << "Unexpect error message: '" << status.Message() + << "', expect 'No mutable location to be deleted'"; +} + +TEST_F(TritonRepoAgentModelTest, AcquireLocalLocationAndDelete) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + const char* acquired_location; + status = model->AcquireMutableLocation( + TRITONREPOAGENT_ARTIFACT_FILESYSTEM, &acquired_location); + ASSERT_TRUE(status.IsOk()) + << "Expect successful location acquisition: " << status.AsString(); + + // Check directory + bool is_dir = false; + status = tc::IsDirectory(acquired_location, &is_dir); + ASSERT_TRUE(status.IsOk()) + << "Expect location proprety can be checked: " << status.AsString(); + EXPECT_TRUE(is_dir) << "Expect a directory is returned as mutable location"; + tc::FileSystemType type = tc::FileSystemType::LOCAL; + status = tc::GetFileSystemType(acquired_location, &type); + ASSERT_TRUE(status.IsOk()) + << "Expect location filesystem type can be checked: " + << status.AsString(); + EXPECT_EQ(type, tc::FileSystemType::LOCAL) + << "Expect a local mutable location is acquired"; + + status = model->DeleteMutableLocation(); + ASSERT_TRUE(status.IsOk()) + << "Expect successful location deletion: " << status.AsString(); + // Check directory + bool exists = true; + status = tc::FileExists(acquired_location, &exists); + ASSERT_TRUE(status.IsOk()) + << "Expect location proprety can be checked: " << status.AsString(); + EXPECT_FALSE(exists) << "Expect the mutable location no longer exists"; +} + +TEST_F(TritonRepoAgentModelTest, AcquireLocalLocationTwice) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + + const char* acquired_location; + status = model->AcquireMutableLocation( + TRITONREPOAGENT_ARTIFACT_FILESYSTEM, &acquired_location); + ASSERT_TRUE(status.IsOk()) + << "Expect successful location acquisition: " << status.AsString(); + + // Acquire the same type again + const char* second_acquired_location; + status = model->AcquireMutableLocation( + TRITONREPOAGENT_ARTIFACT_FILESYSTEM, &second_acquired_location); + ASSERT_TRUE(status.IsOk()) + << "Expect successful location acquisition: " << status.AsString(); + EXPECT_EQ( + std::string(acquired_location), std::string(second_acquired_location)) + << "Expect the same location is returned"; +} + +TEST_F(TritonRepoAgentModelTest, DeleteTwiceAfterAcquire) +{ + // Create agent to be 
associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + const char* acquired_location; + status = model->AcquireMutableLocation( + TRITONREPOAGENT_ARTIFACT_FILESYSTEM, &acquired_location); + ASSERT_TRUE(status.IsOk()) + << "Expect successful location acquisition: " << status.AsString(); + + status = model->DeleteMutableLocation(); + ASSERT_TRUE(status.IsOk()) + << "Expect successful location deletion: " << status.AsString(); + status = model->DeleteMutableLocation(); + ASSERT_FALSE(status.IsOk()) + << "Expect error returned from DeleteMutableLocation()"; + EXPECT_NE( + status.Message().find("No mutable location to be deleted"), + std::string::npos) + << "Unexpect error message: '" << status.Message() + << "', expect 'No mutable location to be deleted'"; +} + +TEST_F(TritonRepoAgentModelTest, AcquireRemoteLocation) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + + const char* acquired_location; + status = model->AcquireMutableLocation( + TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM, &acquired_location); + ASSERT_FALSE(status.IsOk()) + << "Expect error returned from AcquireMutableLocation()"; + const std::string search_msg = + "Unexpected artifact type, expects 'TRITONREPOAGENT_ARTIFACT_FILESYSTEM'"; + EXPECT_NE(status.Message().find(search_msg), std::string::npos) + << "Unexpect error message: '" << status.Message() << "', expect '" + << search_msg << "'"; +} + +TEST_F(TritonRepoAgentModelTest, AgentParameters) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + tc::TritonRepoAgent::Parameters expected_params{{"key_a", "value_b"}, + {"key_b", "value_b"}}; + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + expected_params, &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + auto agent_params = model->AgentParameters(); + ASSERT_EQ(agent_params.size(), expected_params.size()); + for (size_t idx = 0; idx < agent_params.size(); ++idx) { + EXPECT_EQ(agent_params[idx].first, expected_params[idx].first); + EXPECT_EQ(agent_params[idx].second, expected_params[idx].second); + } +} + +TEST_F(TritonRepoAgentModelTest, State) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = + tc::TritonRepoAgent::Create("agent", "simple_agent_path", &agent); + 
ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + + // Create model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + auto state = model->State(); + ASSERT_TRUE(state == nullptr) << "Expect state is not set"; + bool state_value = true; + model->SetState(reinterpret_cast(&state_value)); + state = model->State(); + ASSERT_TRUE(state != nullptr) << "Expect state is set"; + EXPECT_EQ(*reinterpret_cast(state), state_value) + << "Expect state value is true"; +} + +TEST_F(TritonRepoAgentModelTest, EmptyLifeCycle) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "log_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::vector log; + agent->SetState(reinterpret_cast(&log)); + + // Create and destroy model + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + model.reset(); + + // Check log + ASSERT_EQ(log.size(), (size_t)2) + << "Expect 2 state of model lifecycle is logged, got " << log.size(); + EXPECT_EQ(log[0], "Model Initialized"); + EXPECT_EQ(log[1], "Model Finalized"); +} + +TEST_F(TritonRepoAgentModelTest, HalfLifeCycle) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "log_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::vector log; + agent->SetState(reinterpret_cast(&log)); + + std::unique_ptr model; + // Create and destroy model in situations that a full lifecycle should run + std::vector> situations{ + {TRITONREPOAGENT_ACTION_LOAD}, + {TRITONREPOAGENT_ACTION_LOAD, TRITONREPOAGENT_ACTION_LOAD_FAIL}}; + std::vector expected_log{ + "Model Initialized", "TRITONREPOAGENT_ACTION_LOAD", + "TRITONREPOAGENT_ACTION_LOAD_FAIL", "Model Finalized"}; + for (const auto& situation : situations) { + log.clear(); + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (const auto action : situation) { + status = model->InvokeAgent(action); + EXPECT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) + << ": " << status.AsString(); + } + model.reset(); + + // Check log + ASSERT_EQ(log.size(), expected_log.size()) + << "Expect " << expected_log.size() + << " state of model lifecycle is logged, got " << log.size(); + for (size_t i = 0; i < log.size(); ++i) { + EXPECT_EQ(log[i], expected_log[i]); + } + } +} + +TEST_F(TritonRepoAgentModelTest, FullLifeCycle) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "log_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::vector log; + agent->SetState(reinterpret_cast(&log)); + + 
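+  // All situations below are expected to produce the same complete log
+  // (expected_log): any lifecycle actions not invoked explicitly here are
+  // expected to be issued when the model is destroyed.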
std::unique_ptr model; + // Create and destroy model in situations that a full lifecycle should run + std::vector> situations{ + {TRITONREPOAGENT_ACTION_LOAD, TRITONREPOAGENT_ACTION_LOAD_COMPLETE}, + {TRITONREPOAGENT_ACTION_LOAD, TRITONREPOAGENT_ACTION_LOAD_COMPLETE, + TRITONREPOAGENT_ACTION_UNLOAD}, + {TRITONREPOAGENT_ACTION_LOAD, TRITONREPOAGENT_ACTION_LOAD_COMPLETE, + TRITONREPOAGENT_ACTION_UNLOAD, TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE}}; + std::vector expected_log{ + "Model Initialized", + "TRITONREPOAGENT_ACTION_LOAD", + "TRITONREPOAGENT_ACTION_LOAD_COMPLETE", + "TRITONREPOAGENT_ACTION_UNLOAD", + "TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE", + "Model Finalized"}; + for (const auto& situation : situations) { + log.clear(); + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (const auto action : situation) { + status = model->InvokeAgent(action); + EXPECT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) + << ": " << status.AsString(); + } + model.reset(); + + // Check log + ASSERT_EQ(log.size(), expected_log.size()) + << "Expect " << expected_log.size() + << " state of model lifecycle is logged, got " << log.size(); + for (size_t i = 0; i < log.size(); ++i) { + EXPECT_EQ(log[i], expected_log[i]); + } + } +} + +TEST_F(TritonRepoAgentModelTest, WrongLifeCycle) +{ + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "log_agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::vector log; + agent->SetState(reinterpret_cast(&log)); + + // Create model and run all action combinations + std::vector> valid_lifecycles{ + {TRITONREPOAGENT_ACTION_LOAD, TRITONREPOAGENT_ACTION_LOAD_FAIL}, + {TRITONREPOAGENT_ACTION_LOAD, TRITONREPOAGENT_ACTION_LOAD_COMPLETE, + TRITONREPOAGENT_ACTION_UNLOAD, TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE}}; + std::vector available_actions{ + TRITONREPOAGENT_ACTION_LOAD, TRITONREPOAGENT_ACTION_LOAD_FAIL, + TRITONREPOAGENT_ACTION_LOAD_COMPLETE, TRITONREPOAGENT_ACTION_UNLOAD, + TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE}; + std::map> + valid_actions{{TRITONREPOAGENT_ACTION_LOAD, + {TRITONREPOAGENT_ACTION_LOAD_FAIL, + TRITONREPOAGENT_ACTION_LOAD_COMPLETE}}, + {TRITONREPOAGENT_ACTION_LOAD_FAIL, {}}, + {TRITONREPOAGENT_ACTION_LOAD_COMPLETE, + {TRITONREPOAGENT_ACTION_UNLOAD}}, + {TRITONREPOAGENT_ACTION_UNLOAD, + {TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE}}, + {TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE, {}}}; + for (const auto& valid_lifecycle : valid_lifecycles) { + log.clear(); + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (size_t idx = 0; idx < valid_lifecycle.size(); ++idx) { + const auto next_lifecycle_action = valid_lifecycle[idx]; + // Handle the first action specially + if (idx == 0) { + for (const auto action : available_actions) { + if (action == valid_lifecycle[0]) { + continue; + } + status = model->InvokeAgent(action); + if (status.IsOk()) { + for (const auto& state_log : log) { + EXPECT_TRUE(false) << state_log; + } + } + ASSERT_FALSE(status.IsOk()) + << 
"Unexpect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action); + } + status = model->InvokeAgent(valid_lifecycle[0]); + if (!status.IsOk()) { + for (const auto& state_log : log) { + EXPECT_TRUE(false) << state_log; + } + } + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(next_lifecycle_action) + << ": " << status.AsString(); + continue; + } + const auto& current_valid_actions = + valid_actions[valid_lifecycle[idx - 1]]; + for (const auto action : available_actions) { + if (current_valid_actions.find(action) != current_valid_actions.end()) { + continue; + } + status = model->InvokeAgent(action); + if (status.IsOk()) { + for (const auto& state_log : log) { + EXPECT_TRUE(false) << state_log; + } + } + ASSERT_FALSE(status.IsOk()) + << "Unexpect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action); + } + status = model->InvokeAgent(next_lifecycle_action); + if (!status.IsOk()) { + for (const auto& state_log : log) { + EXPECT_TRUE(false) << state_log; + } + } + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(next_lifecycle_action) << ": " + << status.AsString(); + } + } +} + +class TritonRepoAgentAPITest : public ::testing::Test { + public: + static std::function agent_init_fn_; + static std::function agent_fini_fn_; + static std::function + model_init_fn_; + static std::function + model_action_fn_; + static std::function + model_fini_fn_; + + protected: + void SetUp() override + { + simple_config_.set_name("simple_config"); + // Add a agent handle for flexible testing + tc::TritonRepoAgent::TritonRepoAgentInitFn_t AgentInitFn = + [](TRITONREPOAGENT_Agent* agent) -> TRITONSERVER_Error* { + if (agent_init_fn_ != nullptr) { + agent_init_fn_(agent); + } + return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentFiniFn_t AgentFiniFn = + [](TRITONREPOAGENT_Agent* agent) -> TRITONSERVER_Error* { + if (agent_fini_fn_ != nullptr) { + agent_fini_fn_(agent); + } + return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentModelInitFn_t ModelInitFn = + [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model) -> TRITONSERVER_Error* { + if (model_init_fn_ != nullptr) { + model_init_fn_(agent, model); + } + return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t ModelActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) + -> TRITONSERVER_Error* { + if (model_action_fn_ != nullptr) { + model_action_fn_(agent, model); + } + return nullptr; + }; + tc::TritonRepoAgent::TritonRepoAgentModelFiniFn_t ModelFiniFn = + [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model) -> TRITONSERVER_Error* { + if (model_fini_fn_ != nullptr) { + model_fini_fn_(agent, model); + } + return nullptr; + }; + auto agent_handle = MockSharedLibraryHandle(); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_Initialize", reinterpret_cast(AgentInitFn)); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_Finalize", reinterpret_cast(AgentFiniFn)); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelInitialize", + reinterpret_cast(ModelInitFn)); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelAction", reinterpret_cast(ModelActionFn)); + agent_handle.AddEntryPoint( + "TRITONREPOAGENT_ModelFinalize", reinterpret_cast(ModelFiniFn)); + global_mock_agents.emplace("agent_path", agent_handle); + } + void TearDown() override 
+ { + global_mock_agents.clear(); + agent_init_fn_ = nullptr; + agent_fini_fn_ = nullptr; + model_init_fn_ = nullptr; + model_action_fn_ = nullptr; + model_fini_fn_ = nullptr; + } + + TRITONREPOAGENT_ArtifactType original_type_ = + TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + const std::string original_location_ = "/original"; + inference::ModelConfig simple_config_; + + std::vector> valid_lifecycles_{ + {TRITONREPOAGENT_ACTION_LOAD, TRITONREPOAGENT_ACTION_LOAD_FAIL}, + {TRITONREPOAGENT_ACTION_LOAD, TRITONREPOAGENT_ACTION_LOAD_COMPLETE, + TRITONREPOAGENT_ACTION_UNLOAD, TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE}}; +}; + +std::function + TritonRepoAgentAPITest::agent_init_fn_ = nullptr; +std::function + TritonRepoAgentAPITest::agent_fini_fn_ = nullptr; +std::function + TritonRepoAgentAPITest::model_init_fn_ = nullptr; +std::function + TritonRepoAgentAPITest::model_action_fn_ = nullptr; +std::function + TritonRepoAgentAPITest::model_fini_fn_ = nullptr; + +TEST_F(TritonRepoAgentAPITest, TRITONREPOAGENT_ApiVersion) +{ + agent_init_fn_ = + [](TRITONREPOAGENT_Agent* agent) { + uint32_t major = 0; + uint32_t minor = 0; + auto err = TRITONREPOAGENT_ApiVersion(&major, &minor); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ApiVersion() invokation: " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_EQ(major, (uint32_t)TRITONREPOAGENT_API_VERSION_MAJOR) + << "Unexpected major veresion"; + EXPECT_EQ(minor, (uint32_t)TRITONREPOAGENT_API_VERSION_MINOR) + << "Unexpected major veresion"; + } + }; + agent_fini_fn_ = agent_init_fn_; + model_init_fn_ = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model) { + uint32_t major = 0; + uint32_t minor = 0; + auto err = TRITONREPOAGENT_ApiVersion(&major, &minor); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ApiVersion() invokation: " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_EQ(major, (uint32_t)TRITONREPOAGENT_API_VERSION_MAJOR) + << "Unexpected major veresion"; + EXPECT_EQ(minor, (uint32_t)TRITONREPOAGENT_API_VERSION_MINOR) + << "Unexpected major veresion"; + } + }; + model_action_fn_ = model_init_fn_; + model_fini_fn_ = model_init_fn_; + + const auto lifecycles = valid_lifecycles_; + for (const auto& lifecycle : lifecycles) { + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (const auto action : lifecycle) { + status = model->InvokeAgent(action); + ASSERT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) + << ": " << status.AsString(); + } + } +} + +TEST_F(TritonRepoAgentAPITest, TRITONREPOAGENT_ModelRepositoryLocation) +{ + model_init_fn_ = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model) { + TRITONREPOAGENT_ArtifactType artifact_type = + TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM; + const char* location = nullptr; + auto err = TRITONREPOAGENT_ModelRepositoryLocation( + agent, model, &artifact_type, &location); + if (err != nullptr) { + 
EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelRepositoryLocation(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_EQ(artifact_type, TRITONREPOAGENT_ARTIFACT_FILESYSTEM) + << "Unexpected artifact type"; + EXPECT_EQ(std::string(location), "/original") + << "Unexpected location"; + } + }; + model_action_fn_ = model_init_fn_; + model_fini_fn_ = model_init_fn_; + + const auto lifecycles = valid_lifecycles_; + for (const auto& lifecycle : lifecycles) { + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (const auto action : lifecycle) { + status = model->InvokeAgent(action); + ASSERT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) + << ": " << status.AsString(); + } + } +} + +TEST_F( + TritonRepoAgentAPITest, + TRITONREPOAGENT_ModelRepositoryLocationAcquireRemote) +{ + model_init_fn_ = [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model) { + TRITONREPOAGENT_ArtifactType artifact_type = + TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM; + const char* location = nullptr; + auto err = TRITONREPOAGENT_ModelRepositoryLocationAcquire( + agent, model, artifact_type, &location); + if (err != nullptr) { + const std::string err_msg = TRITONSERVER_ErrorMessage(err); + const std::string search_msg = + "Unexpected artifact type, expects " + "'TRITONREPOAGENT_ARTIFACT_FILESYSTEM'"; + EXPECT_NE(err_msg.find(search_msg), std::string::npos) + << "Unexpect error message: '" << err_msg << "', expect '" + << search_msg << "'"; + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(false) << "Expect error returned from " + "TRITONREPOAGENT_ModelRepositoryLocationAcquire()"; + } + }; + model_action_fn_ = model_init_fn_; + model_fini_fn_ = model_init_fn_; + + const auto lifecycles = valid_lifecycles_; + for (const auto& lifecycle : lifecycles) { + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (const auto action : lifecycle) { + status = model->InvokeAgent(action); + ASSERT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) + << ": " << status.AsString(); + } + } +} + +TEST_F(TritonRepoAgentAPITest, TRITONREPOAGENT_ModelRepositoryLocationAcquire) +{ + model_init_fn_ = [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model) { + // Acquire, acquire (same), release + TRITONREPOAGENT_ArtifactType artifact_type = + TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + const char* location = nullptr; + auto err = TRITONREPOAGENT_ModelRepositoryLocationAcquire( + agent, model, artifact_type, 
&location); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful " + "TRITONREPOAGENT_ModelRepositoryLocationAcquire(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } + + std::string acquired_location = location; + err = TRITONREPOAGENT_ModelRepositoryLocationAcquire( + agent, model, artifact_type, &location); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful " + "TRITONREPOAGENT_ModelRepositoryLocationAcquire(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_EQ(acquired_location, std::string(location)) + << "Expect the same location is acquired"; + } + }; + model_action_fn_ = model_init_fn_; + model_fini_fn_ = model_init_fn_; + + const auto lifecycles = valid_lifecycles_; + for (const auto& lifecycle : lifecycles) { + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (const auto action : lifecycle) { + status = model->InvokeAgent(action); + ASSERT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) + << ": " << status.AsString(); + } + } +} + +TEST_F(TritonRepoAgentAPITest, TRITONREPOAGENT_ModelRepositoryLocationRelease) +{ + model_init_fn_ = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model) { + // relase (fail), acquire, release + const char* location = "nonexisting_location"; + auto err = TRITONREPOAGENT_ModelRepositoryLocationRelease( + agent, model, location); + if (err != nullptr) { + const std::string search_msg = "No mutable location to be deleted"; + const std::string err_msg = TRITONSERVER_ErrorMessage(err); + EXPECT_NE(err_msg.find(search_msg), std::string::npos) + << "Unexpect error message: '" << err_msg << "', expect '" + << search_msg << "'"; + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(false) + << "Expect error returned from " + "TRITONREPOAGENT_ModelRepositoryLocationRelease()"; + } + + TRITONREPOAGENT_ArtifactType artifact_type = + TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + err = TRITONREPOAGENT_ModelRepositoryLocationAcquire( + agent, model, artifact_type, &location); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelRepositoryLocation(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } + + err = TRITONREPOAGENT_ModelRepositoryLocationRelease( + agent, model, location); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelRepositoryLocation(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } + }; + model_action_fn_ = model_init_fn_; + model_fini_fn_ = model_init_fn_; + + const auto lifecycles = valid_lifecycles_; + for (const auto& lifecycle : lifecycles) { + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, 
original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (const auto action : lifecycle) { + status = model->InvokeAgent(action); + ASSERT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) + << ": " << status.AsString(); + } + } +} + +TEST_F(TritonRepoAgentAPITest, TRITONREPOAGENT_ModelRepositoryUpdate) +{ + static std::string current_location = original_location_; + static TRITONREPOAGENT_ArtifactType current_type = + TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + model_init_fn_ = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model) { + std::string new_location = current_location + "_new"; + TRITONREPOAGENT_ArtifactType artifact_type = + TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + const char* location = new_location.c_str(); + auto err = TRITONREPOAGENT_ModelRepositoryUpdate( + agent, model, artifact_type, location); + if (err != nullptr) { + const std::string search_msg = + "location can only be updated during TRITONREPOAGENT_ACTION_LOAD"; + const std::string err_msg = TRITONSERVER_ErrorMessage(err); + EXPECT_NE(err_msg.find(search_msg), std::string::npos) + << "Unexpect error message: '" << err_msg << "', expect '" + << search_msg << "...'"; + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(false) << "Expect error returned from " + "TRITONREPOAGENT_ModelRepositoryUpdate()"; + } + + // Check location shouldn't be changed + err = TRITONREPOAGENT_ModelRepositoryLocation( + agent, model, &artifact_type, &location); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelRepositoryLocation(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_EQ(artifact_type, current_type) << "Unexpected artifact type"; + EXPECT_EQ(std::string(location), current_location) + << "Unexpected location"; + } + }; + model_action_fn_ = model_init_fn_; + model_fini_fn_ = model_init_fn_; + + // Overriding the model action function in agent handle because the action + // type needs to be checked here + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t ModelActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) -> TRITONSERVER_Error* { + std::string new_location = current_location + "_new"; + TRITONREPOAGENT_ArtifactType artifact_type = + TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM; + const char* location = new_location.c_str(); + auto err = TRITONREPOAGENT_ModelRepositoryUpdate( + agent, model, artifact_type, location); + if (action_type == TRITONREPOAGENT_ACTION_LOAD) { + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelRepositoryUpdate(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + current_location = new_location; + current_type = artifact_type; + } + } else { + if (err != nullptr) { + const std::string search_msg = + "location can only be updated during TRITONREPOAGENT_ACTION_LOAD"; + const std::string err_msg = TRITONSERVER_ErrorMessage(err); + EXPECT_NE(err_msg.find(search_msg), std::string::npos) + << "Unexpect error message: '" << err_msg << "', expect '" + << search_msg << "...'"; + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(false) << "Expect error returned from " + "TRITONREPOAGENT_ModelRepositoryUpdate()"; + } + } + + // Check location + err = 
TRITONREPOAGENT_ModelRepositoryLocation( + agent, model, &artifact_type, &location); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelRepositoryLocation(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_EQ(artifact_type, current_type) << "Unexpected artifact type"; + EXPECT_EQ(std::string(location), current_location) + << "Unexpected location"; + } + return nullptr; + }; + global_mock_agents["agent_path"].AddEntryPoint( + "TRITONREPOAGENT_ModelAction", reinterpret_cast(ModelActionFn)); + + const auto lifecycles = valid_lifecycles_; + for (const auto& lifecycle : lifecycles) { + // Reset location and type + current_location = original_location_; + current_type = TRITONREPOAGENT_ARTIFACT_FILESYSTEM; + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, current_location, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (const auto action : lifecycle) { + status = model->InvokeAgent(action); + ASSERT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) + << ": " << status.AsString(); + } + } +} + +TEST_F(TritonRepoAgentAPITest, TRITONREPOAGENT_ModelParameter) +{ + static tc::TritonRepoAgent::Parameters expected_params{{"key_a", "value_a"}, + {"key_b", "value_b"}}; + model_init_fn_ = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model) { + uint32_t count; + auto err = TRITONREPOAGENT_ModelParameterCount(agent, model, &count); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelParameterCount(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_EQ(count, expected_params.size()); + } + + const char* parameter_name = nullptr; + const char* parameter_value = nullptr; + for (size_t idx = 0; idx < count; ++idx) { + err = TRITONREPOAGENT_ModelParameter( + agent, model, idx, ¶meter_name, ¶meter_value); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelParameter(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_EQ(std::string(parameter_name), expected_params[idx].first); + EXPECT_EQ( + std::string(parameter_value), expected_params[idx].second); + } + } + // out of range + err = TRITONREPOAGENT_ModelParameter( + agent, model, count, ¶meter_name, ¶meter_value); + if (err != nullptr) { + const std::string search_msg = + "index out of range for model parameters"; + const std::string err_msg = TRITONSERVER_ErrorMessage(err); + EXPECT_NE(err_msg.find(search_msg), std::string::npos) + << "Unexpect error message: '" << err_msg << "', expect '" + << search_msg << "...'"; + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(false) + << "Expect error returned from TRITONREPOAGENT_ModelParameter()"; + } + }; + model_action_fn_ = model_init_fn_; + model_fini_fn_ = model_init_fn_; + + const auto lifecycles = valid_lifecycles_; + for (const auto& lifecycle : lifecycles) { + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", 
"agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + expected_params, &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (const auto action : lifecycle) { + status = model->InvokeAgent(action); + ASSERT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) + << ": " << status.AsString(); + } + } +} + +TEST_F(TritonRepoAgentAPITest, TRITONREPOAGENT_ModelConfig) +{ + model_init_fn_ = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model) { + TRITONSERVER_Message* config = nullptr; + auto err = TRITONREPOAGENT_ModelConfig(agent, model, 1, &config); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelConfig(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } + const char* base = nullptr; + size_t byte_size = 0; + err = TRITONSERVER_MessageSerializeToJson(config, &base, &byte_size); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONSERVER_MessageSerializeToJson(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + const std::string search_msg = "simple_config"; + const std::string serialized_config(base, byte_size); + EXPECT_NE(serialized_config.find(search_msg), std::string::npos) + << "Expect finding '" << search_msg + << "' in returned config: " << serialized_config; + } + + // unsupport version + err = TRITONREPOAGENT_ModelConfig(agent, model, 2, &config); + if (err != nullptr) { + const std::string search_msg = + "model configuration version 2 not supported, supported versions " + "are: 1"; + const std::string err_msg = TRITONSERVER_ErrorMessage(err); + EXPECT_NE(err_msg.find(search_msg), std::string::npos) + << "Unexpect error message: '" << err_msg << "', expect '" + << search_msg << "...'"; + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(false) + << "Expect error returned from TRITONREPOAGENT_ModelConfig()"; + } + }; + model_action_fn_ = model_init_fn_; + model_fini_fn_ = model_init_fn_; + + const auto lifecycles = valid_lifecycles_; + for (const auto& lifecycle : lifecycles) { + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (const auto action : lifecycle) { + status = model->InvokeAgent(action); + ASSERT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) + << ": " << status.AsString(); + } + } +} + +TEST_F(TritonRepoAgentAPITest, TRITONREPOAGENT_ModelState) +{ + model_init_fn_ = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model) { + size_t* state = nullptr; + auto err = + TRITONREPOAGENT_ModelState(model, reinterpret_cast(&state)); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelState(): " + << TRITONSERVER_ErrorMessage(err); + 
TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(state == nullptr) << "Expect state is not set"; + } + state = new size_t(0); + err = TRITONREPOAGENT_ModelSetState( + model, reinterpret_cast(state)); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelSetState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + delete state; + } + }; + model_fini_fn_ = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model) { + size_t* state = nullptr; + auto err = + TRITONREPOAGENT_ModelState(model, reinterpret_cast(&state)); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(state != nullptr) << "Expect state is set"; + EXPECT_EQ(*state, size_t(0)); + } + + // Sanity check that set state works elsewhere + size_t* new_state = new size_t(*state); + delete state; + err = TRITONREPOAGENT_ModelSetState( + model, reinterpret_cast(new_state)); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelSetState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } + + // Delete state before end of model lifecycle + delete new_state; + }; + // Overriding the model action function in agent handle because the action + // type needs to be checked here + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t ModelActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) -> TRITONSERVER_Error* { + size_t* state = nullptr; + auto err = + TRITONREPOAGENT_ModelState(model, reinterpret_cast(&state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_ModelState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } + EXPECT_TRUE(state != nullptr) << "Expect state is set"; + switch (action_type) { + case TRITONREPOAGENT_ACTION_LOAD: { + EXPECT_EQ(*state, size_t(0)); + ++*state; + break; + } + case TRITONREPOAGENT_ACTION_LOAD_COMPLETE: { + EXPECT_EQ(*state, size_t(1)); + ++*state; + break; + } + case TRITONREPOAGENT_ACTION_LOAD_FAIL: { + EXPECT_EQ(*state, size_t(1)); + --*state; + break; + } + case TRITONREPOAGENT_ACTION_UNLOAD: { + EXPECT_EQ(*state, size_t(2)); + --*state; + break; + } + case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: { + EXPECT_EQ(*state, size_t(1)); + --*state; + break; + } + } + + // Sanity check that set state works elsewhere + size_t* new_state = new size_t(*state); + delete state; + err = TRITONREPOAGENT_ModelSetState( + model, reinterpret_cast(new_state)); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelSetState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + delete new_state; + } + return nullptr; + }; + global_mock_agents["agent_path"].AddEntryPoint( + "TRITONREPOAGENT_ModelAction", reinterpret_cast(ModelActionFn)); + + + const auto lifecycles = valid_lifecycles_; + for (const auto& lifecycle : lifecycles) { + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: " << status.AsString(); + std::unique_ptr model; + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + 
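+    // For reference, the ModelActionFn registered above encodes the expected
+    // counter protocol for the opaque model state (derived from its EXPECT_EQ
+    // checks): LOAD 0 -> 1, LOAD_COMPLETE 1 -> 2, UNLOAD 2 -> 1,
+    // UNLOAD_COMPLETE 1 -> 0, and LOAD_FAIL 1 -> 0, so the state is back to 0
+    // by the time model_fini_fn_ runs.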
ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + for (const auto action : lifecycle) { + status = model->InvokeAgent(action); + ASSERT_TRUE(status.IsOk()) << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) + << ": " << status.AsString(); + } + } +} + +TEST_F(TritonRepoAgentAPITest, TRITONREPOAGENT_AgentState) +{ + // Two models share one agent, check if agent state is properly shared + agent_init_fn_ = [](TRITONREPOAGENT_Agent* agent) { + size_t* state = nullptr; + auto err = TRITONREPOAGENT_State(agent, reinterpret_cast(&state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_State(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(state == nullptr) << "Expect state is not set"; + } + state = new size_t(0); + err = TRITONREPOAGENT_SetState(agent, reinterpret_cast(state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_SetState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + delete state; + } + }; + agent_fini_fn_ = [](TRITONREPOAGENT_Agent* agent) { + size_t* state = nullptr; + auto err = TRITONREPOAGENT_State(agent, reinterpret_cast(&state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_State(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(state != nullptr) << "Expect state is set"; + EXPECT_EQ(*state, size_t(0)); + } + + // Sanity check that set state works elsewhere + size_t* new_state = new size_t(*state); + delete state; + err = TRITONREPOAGENT_SetState(agent, reinterpret_cast(new_state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_SetState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } + + // Delete state before end of agent lifecycle + delete new_state; + }; + model_init_fn_ = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model) { + size_t* state = nullptr; + auto err = + TRITONREPOAGENT_State(agent, reinterpret_cast(&state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_State(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(state != nullptr) << "Expect state is set"; + } + + // Agent state maybe 0 or 1 depending on the order of model lifecycle, + // record that in model state to keep track of the order + if ((*state == 0) || (*state == 1)) { + size_t* model_state = new size_t(*state); + err = TRITONREPOAGENT_ModelSetState( + model, reinterpret_cast(model_state)); + if (err != nullptr) { + EXPECT_TRUE(false) + << "Expect successful TRITONREPOAGENT_ModelSetState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } + } else { + EXPECT_TRUE(false) << "Expect agent state is either 0 or 1"; + } + + // Sanity check that set state works elsewhere + ++*state; + size_t* new_state = new size_t(*state); + delete state; + err = + TRITONREPOAGENT_SetState(agent, reinterpret_cast(new_state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_SetState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + delete new_state; + } + }; + model_fini_fn_ = [](TRITONREPOAGENT_Agent* agent, + TRITONREPOAGENT_AgentModel* model) { + size_t* model_state = nullptr; + auto err = TRITONREPOAGENT_ModelState( + model, 
reinterpret_cast(&model_state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_ModelState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(model_state != nullptr) << "Expect state is set"; + } + + size_t* state = nullptr; + err = TRITONREPOAGENT_State(agent, reinterpret_cast(&state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_State(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(state != nullptr) << "Expect state is set"; + EXPECT_EQ(*state, size_t(2) - *model_state); + } + + // Sanity check that set state works elsewhere + --*state; + size_t* new_state = new size_t(*state); + delete state; + err = TRITONREPOAGENT_SetState(agent, reinterpret_cast(new_state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_SetState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + delete new_state; + } + + // Delete state before end of model lifecycle + delete model_state; + }; + // Overriding the model action function in agent handle because the action + // type needs to be checked here + tc::TritonRepoAgent::TritonRepoAgentModelActionFn_t ModelActionFn = + [](TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model, + const TRITONREPOAGENT_ActionType action_type) -> TRITONSERVER_Error* { + size_t* model_state = nullptr; + auto err = TRITONREPOAGENT_ModelState( + model, reinterpret_cast(&model_state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_ModelState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } else { + EXPECT_TRUE(model_state != nullptr) << "Expect state is set"; + } + + size_t* state = nullptr; + err = TRITONREPOAGENT_State(agent, reinterpret_cast(&state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_State(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } + EXPECT_TRUE(state != nullptr) << "Expect state is set"; + switch (action_type) { + case TRITONREPOAGENT_ACTION_LOAD: { + EXPECT_EQ(*state, size_t(2) + *model_state); + ++*state; + break; + } + case TRITONREPOAGENT_ACTION_LOAD_COMPLETE: { + EXPECT_EQ(*state, size_t(4) + *model_state); + ++*state; + break; + } + case TRITONREPOAGENT_ACTION_LOAD_FAIL: { + EXPECT_EQ(*state, size_t(4) - *model_state); + --*state; + break; + } + case TRITONREPOAGENT_ACTION_UNLOAD: { + EXPECT_EQ(*state, size_t(6) - *model_state); + --*state; + break; + } + case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: { + EXPECT_EQ(*state, size_t(4) - *model_state); + --*state; + break; + } + } + + // Sanity check that set state works elsewhere + size_t* new_state = new size_t(*state); + delete state; + err = TRITONREPOAGENT_SetState(agent, reinterpret_cast(new_state)); + if (err != nullptr) { + EXPECT_TRUE(false) << "Expect successful TRITONREPOAGENT_SetState(): " + << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + delete new_state; + } + return nullptr; + }; + global_mock_agents["agent_path"].AddEntryPoint( + "TRITONREPOAGENT_ModelAction", reinterpret_cast(ModelActionFn)); + + + const auto lifecycles = valid_lifecycles_; + for (const auto& lifecycle : lifecycles) { + // Create agent to be associated with the model + std::shared_ptr agent; + auto status = tc::TritonRepoAgent::Create("agent", "agent_path", &agent); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent creation: 
" << status.AsString(); + std::vector> models(2); + for (auto& model : models) { + status = tc::TritonRepoAgentModel::Create( + original_type_, original_location_, simple_config_, agent, + tc::TritonRepoAgent::Parameters(), &model); + ASSERT_TRUE(status.IsOk()) + << "Expect successful model creation: " << status.AsString(); + } + for (const auto action : lifecycle) { + for (auto& model : models) { + status = model->InvokeAgent(action); + ASSERT_TRUE(status.IsOk()) + << "Expect successful agent invocation with " + << tc::TRITONREPOAGENT_ActionTypeString(action) << ": " + << status.AsString(); + } + } + } +} + +} // namespace + +int +main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/3rdparty/core-r22.12/src/test/response_cache_test.cc b/3rdparty/core-r22.12/src/test/response_cache_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4662f07b440f54c59d197f10c1f81ce8a294d443 --- /dev/null +++ b/3rdparty/core-r22.12/src/test/response_cache_test.cc @@ -0,0 +1,981 @@ +// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+#include "gtest/gtest.h" + +#include +#include "memory.h" +#include "response_cache.h" +#include "triton/common/logging.h" + +namespace tc = triton::core; + +/* Mock classes for Unit Testing */ +namespace triton { namespace core { + +// +// InferenceResponseFactory +// +Status +InferenceResponseFactory::CreateResponse( + std::unique_ptr* response) const +{ + response->reset(new InferenceResponse( + model_, id_, allocator_, alloc_userp_, response_fn_, response_userp_, + response_delegator_)); + + return Status::Success; +} + +// +// InferenceRequest +// +InferenceRequest::InferenceRequest( + Model* model, const int64_t requested_model_version) + : needs_normalization_(true), model_raw_(model), + requested_model_version_(requested_model_version), flags_(0), + correlation_id_(0), batch_size_(0), timeout_us_(0), collect_stats_(true) +{ + // Unit test doesn't need actual response factory logic + // or other priority/request_counting logic, it just needs + // a non-null reponse factory object. + response_factory_.reset(new InferenceResponseFactory()); +} + +InferenceRequest::Input::Input( + const std::string& name, const inference::DataType datatype, + const int64_t* shape, const uint64_t dim_count) + : name_(name), datatype_(datatype), + original_shape_(shape, shape + dim_count), is_shape_tensor_(false), + data_(new MemoryReference), has_host_policy_specific_data_(false) +{ +} + +// Use const global var as locals can't be returned in ModelName(), +// and we don't care about the model for the unit test +const std::string MODEL = "model"; + +const std::string& +InferenceRequest::ModelName() const +{ + return MODEL; +} + +int64_t +InferenceRequest::ActualModelVersion() const +{ + // Not using model in unit test mock + return requested_model_version_; +} + +Status +InferenceRequest::PrepareForInference() +{ + // Remove override inputs as those are added during any previous + // inference execution. + inputs_.clear(); + override_inputs_.clear(); + + // Initially show the actual inputs to be only the original + // inputs. If overrides are added later they will be added to + // 'inputs_'. 
+ for (auto& pr : original_inputs_) { + inputs_.emplace(std::make_pair(pr.first, std::addressof(pr.second))); + } + + // Clear the timestamps + queue_start_ns_ = 0; +#ifdef TRITON_ENABLE_STATS + request_start_ns_ = 0; +#endif // TRITON_ENABLE_STATS + + return Status::Success; +} + +Status +InferenceRequest::Input::DataBuffer( + const size_t idx, const void** base, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) const +{ + *base = data_->BufferAt(idx, byte_size, memory_type, memory_type_id); + + return Status::Success; +} + +Status +InferenceRequest::AddOriginalInput( + const std::string& name, const inference::DataType datatype, + const int64_t* shape, const uint64_t dim_count, + InferenceRequest::Input** input) +{ + const auto& pr = original_inputs_.emplace( + std::piecewise_construct, std::forward_as_tuple(name), + std::forward_as_tuple(name, datatype, shape, dim_count)); + if (!pr.second) { + return Status( + Status::Code::INVALID_ARG, + "input '" + name + "' already exists in request"); + } + + if (input != nullptr) { + *input = std::addressof(pr.first->second); + } + + needs_normalization_ = true; + return Status::Success; +} + +Status +InferenceRequest::AddOriginalInput( + const std::string& name, const inference::DataType datatype, + const std::vector& shape, InferenceRequest::Input** input) +{ + return AddOriginalInput(name, datatype, &shape[0], shape.size(), input); +} + +Status +InferenceRequest::Input::AppendData( + const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) +{ + if (byte_size > 0) { + std::static_pointer_cast(data_)->AddBuffer( + static_cast(base), byte_size, memory_type, memory_type_id); + } + + return Status::Success; +} + +// +// InferenceResponse +// + +InferenceResponse::InferenceResponse( + const std::shared_ptr& model, const std::string& id, + const ResponseAllocator* allocator, void* alloc_userp, + TRITONSERVER_InferenceResponseCompleteFn_t response_fn, + void* response_userp, + const std::function< + void(std::unique_ptr&&, const uint32_t)>& delegator) + : model_(model), id_(id), allocator_(allocator), alloc_userp_(alloc_userp), + response_fn_(response_fn), response_userp_(response_userp), + response_delegator_(delegator), null_response_(false) +{ + // Skip allocator logic / references in unit test +} + +std::ostream& +operator<<(std::ostream& out, const InferenceResponse& response) +{ + out << "[0x" << std::addressof(response) << "] " + << "response id: " << response.Id() << std::endl; + + out << "status:" << response.ResponseStatus().AsString() << std::endl; + + return out; +} + +InferenceResponse::Output::~Output() +{ + Status status = ReleaseDataBuffer(); + if (!status.IsOk()) { + std::cerr << "[ERROR] failed to release buffer for output '" << name_ + << "': " << status.AsString(); + } +} + +Status +InferenceResponse::Output::ReleaseDataBuffer() +{ + if (allocated_buffer_ != nullptr) { + free(allocated_buffer_); + } + + allocated_buffer_ = nullptr; + buffer_attributes_.SetByteSize(0); + buffer_attributes_.SetMemoryType(TRITONSERVER_MEMORY_CPU); + buffer_attributes_.SetMemoryTypeId(0); + allocated_userp_ = nullptr; + + return Status::Success; +} + +// Same as defined in infer_response.cc +Status +InferenceResponse::Output::DataBuffer( + const void** buffer, size_t* buffer_byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id, + void** userp) const +{ + *buffer = allocated_buffer_; + *buffer_byte_size = buffer_attributes_.ByteSize(); + *memory_type = 
buffer_attributes_.MemoryType(); + *memory_type_id = buffer_attributes_.MemoryTypeId(); + *userp = allocated_userp_; + return Status::Success; +} + +// Simplified version of AllocateDataBuffer for CPU memory only +Status +InferenceResponse::Output::AllocateDataBuffer( + void** buffer, size_t buffer_byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) +{ + if (allocated_buffer_ != nullptr) { + return Status( + Status::Code::ALREADY_EXISTS, + "allocated buffer for output '" + name_ + "' already exists"); + } + + // Simplifications - CPU memory only for now + if (*memory_type != TRITONSERVER_MEMORY_CPU || *memory_type_id != 0) { + return Status( + Status::Code::INTERNAL, "Only standard CPU memory supported for now"); + } + + // Allocate buffer to copy to + *buffer = malloc(buffer_byte_size); + if (buffer == nullptr || *buffer == nullptr) { + return Status( + Status::Code::INTERNAL, "buffer was nullptr in AllocateDataBuffer"); + } + + // Set relevant member variables for DataBuffer() to return + allocated_buffer_ = *buffer; + buffer_attributes_.SetByteSize(buffer_byte_size); + buffer_attributes_.SetMemoryType(*memory_type); + buffer_attributes_.SetMemoryTypeId(*memory_type_id); + allocated_userp_ = nullptr; + return Status::Success; +} + +Status +InferenceResponse::AddOutput( + const std::string& name, const inference::DataType datatype, + const std::vector& shape, InferenceResponse::Output** output) +{ + outputs_.emplace_back(name, datatype, shape, allocator_, alloc_userp_); + + if (output != nullptr) { + *output = std::addressof(outputs_.back()); + } + + return Status::Success; +} + +InferenceRequest::SequenceId::SequenceId() + : sequence_label_(""), sequence_index_(0), + id_type_(InferenceRequest::SequenceId::DataType::UINT64) +{ +} + +InferenceRequest::SequenceId::SequenceId(const std::string& sequence_label) + : sequence_label_(sequence_label), sequence_index_(0), + id_type_(InferenceRequest::SequenceId::DataType::STRING) +{ +} + +InferenceRequest::SequenceId::SequenceId(uint64_t sequence_index) + : sequence_label_(""), sequence_index_(sequence_index), + id_type_(InferenceRequest::SequenceId::DataType::UINT64) +{ +} + +}} // namespace triton::core + + +namespace { + +// Helpers +void +check_status(tc::Status status) +{ + ASSERT_TRUE(status.IsOk()) << "ERROR: " << status.Message(); +} + +void +cache_stats(std::unique_ptr& cache) +{ + std::cout << "Cache entries: " << cache->NumEntries() << std::endl; + std::cout << "Cache evictions: " << cache->NumEvictions() << std::endl; + std::cout << "Cache free bytes: " << cache->FreeBytes() << std::endl; + std::cout << "Cache alloc'd bytes: " << cache->AllocatedBytes() << std::endl; + std::cout << "Cache total bytes: " << cache->TotalBytes() << std::endl; +} + +void +reset_response( + std::unique_ptr* response, + tc::InferenceRequest* request) +{ + check_status(request->ResponseFactory()->CreateResponse(response)); +} + +// Only support 1-Dimensional data to keep it simple +struct Tensor { + std::string name; + std::vector data; +}; + +// Only support 1-Dimensional data to keep it simple +std::unique_ptr +GenerateResponse( + const tc::InferenceRequest* request, inference::DataType dtype, + TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, + const std::vector& outputs) +{ + std::cout << "Create response object" << std::endl; + std::unique_ptr response; + check_status(request->ResponseFactory()->CreateResponse(&response)); + + std::cout << "Add output metadata to response object" << std::endl; + for (const auto& tensor 
: outputs) { + if (tensor.data.size() == 0) { + std::cout << "[ERROR] Can't generate a request with no output data" + << std::endl; + return nullptr; + } + + tc::InferenceResponse::Output* response_output = nullptr; + std::vector shape{1, -1}; + shape[1] = tensor.data.size(); + uint64_t output_size = sizeof(tensor.data[0]) * tensor.data.size(); + std::cout << "Output size bytes: " << output_size << std::endl; + check_status( + response->AddOutput(tensor.name, dtype, shape, &response_output)); + + std::cout << "Allocate output data buffer for response object" << std::endl; + void* buffer; + check_status(response_output->AllocateDataBuffer( + &buffer, output_size, &memory_type, &memory_type_id)); + if (buffer == nullptr) { + std::cout << "[ERROR] buffer was nullptr;" << std::endl; + return nullptr; + } + // Copy data from output to response buffer + std::memcpy(buffer, tensor.data.data(), output_size); + } + + return response; +} + +// Only support 1-Dimensional data to keep it simple +tc::InferenceRequest* +GenerateRequest( + tc::Model* model, uint64_t model_version, inference::DataType dtype, + TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, + const std::vector& inputs, const std::string& request_id) +{ + auto request = new tc::InferenceRequest(model, model_version); + for (const auto& tensor : inputs) { + if (tensor.data.size() == 0) { + std::cout << "[ERROR] Can't generate a request with no input data" + << std::endl; + return nullptr; + } + + tc::InferenceRequest::Input* request_input = nullptr; + std::vector shape{1, -1}; + shape[1] = tensor.data.size(); + request->AddOriginalInput(tensor.name, dtype, shape, &request_input); + if (request_input == nullptr) { + std::cout << "[ERROR] request_input was nullptr" << std::endl; + return nullptr; + } + + uint64_t input_size = sizeof(tensor.data[0]) * tensor.data.size(); + request_input->AppendData( + tensor.data.data(), input_size, memory_type, memory_type_id); + } + // PrepareForInference for use of ImmutableInputs() + check_status(request->PrepareForInference()); + request->SetId(request_id); // for debugging purposes + return request; +} + +// Test Fixture +class RequestResponseCacheTest : public ::testing::Test { + protected: + void SetUp() override + { + // Sample input data + data0 = {1, 2, 3, 4}; + data1 = {5, 6, 7, 8}; + + // Sample input vectors + inputs0 = std::vector{{"input", data0}}; + inputs1 = std::vector{{"input", data1}}; + inputs2 = std::vector{{"input", data1}}; + inputs3 = std::vector{{"input0", data0}, {"input1", data1}}; + inputs4 = std::vector{{"input1", data1}, {"input0", data0}}; + + // Create three requests with same input name, two with same data, one with + // different data + request0 = GenerateRequest( + model, model_version, dtype, memory_type, memory_type_id, inputs0, + "request0"); + request1 = GenerateRequest( + model, model_version, dtype, memory_type, memory_type_id, inputs1, + "request1"); + request2 = GenerateRequest( + model, model_version, dtype, memory_type, memory_type_id, inputs2, + "request2"); + // Create two requests with the same two inputs but inserted in different + // order + request3 = GenerateRequest( + model, model_version, dtype, memory_type, memory_type_id, inputs3, + "request3"); + request4 = GenerateRequest( + model, model_version, dtype, memory_type, memory_type_id, inputs4, + "request4"); + // Verify requests were created correctly + ASSERT_NE(request0, nullptr); + ASSERT_NE(request1, nullptr); + ASSERT_NE(request2, nullptr); + ASSERT_NE(request3, nullptr); + 
ASSERT_NE(request4, nullptr); + + // Generate a set of unique requests to use for parallelism tests + for (size_t idx = 0; idx < thread_count; idx++) { + std::vector data(thread_count, static_cast(idx)); + std::vector inputs{Tensor{"input" + std::to_string(idx), data}}; + + std::string request_id = "unique" + std::to_string(idx); + std::cout << "Generating request: " << request_id << std::endl; + auto request = GenerateRequest( + model, model_version, dtype, memory_type, memory_type_id, inputs, + request_id); + ASSERT_NE(request, nullptr); + unique_requests.emplace_back(request); + } + ASSERT_EQ(unique_requests.size(), thread_count); + + // Sample outputs + Tensor output_tensor0 = {"output", data0}; + output0_size = sizeof(int) * data0.size(); + outputs0 = std::vector{output_tensor0}; + // Response of 100 ints, taking ~400 bytes at a time + data100 = std::vector(100, 0); + Tensor output_tensor100 = {"output", data100}; + outputs100 = std::vector{output_tensor100}; + + // Sample responses + response0 = GenerateResponse( + request0, dtype, memory_type, memory_type_id, outputs0); + ASSERT_NE(response0, nullptr); + response_400bytes = GenerateResponse( + request0, dtype, memory_type, memory_type_id, outputs100); + ASSERT_NE(response_400bytes, nullptr); + } + + void TearDown() override + { + delete request0; + delete request1; + delete request2; + delete request3; + delete request4; + for (auto r : unique_requests) { + delete r; + } + } + + public: + tc::Model* model = nullptr; + uint64_t model_version = 1; + inference::DataType dtype = inference::DataType::TYPE_INT32; + TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; + int64_t memory_type_id = 0; + size_t thread_count = 10; + uint64_t output0_size; + + std::vector data0, data1, data100; + std::vector inputs0, inputs1, inputs2, inputs3, inputs4, inputs100; + std::vector outputs0, outputs100; + tc::InferenceRequest *request0, *request1, *request2, *request3, *request4; + std::vector unique_requests; + std::unique_ptr response0, response_400bytes; +}; + +// Test hashing for consistency on same request +TEST_F(RequestResponseCacheTest, TestHashing) +{ + // Create cache + std::cout << "Create cache" << std::endl; + uint64_t cache_size = 4 * 1024 * 1024; + std::unique_ptr cache; + tc::RequestResponseCache::Create(cache_size, &cache); + + // Compare hashes + std::cout << "Compare hashes" << std::endl; + check_status(cache->HashAndSet(request0)); + check_status(cache->HashAndSet(request1)); + check_status(cache->HashAndSet(request2)); + check_status(cache->HashAndSet(request3)); + check_status(cache->HashAndSet(request4)); + + std::cout << "request0->CacheKey(): " << request0->CacheKey() << std::endl; + std::cout << "request1->CacheKey(): " << request1->CacheKey() << std::endl; + std::cout << "request2->CacheKey(): " << request2->CacheKey() << std::endl; + std::cout << "request3->CacheKey(): " << request3->CacheKey() << std::endl; + std::cout << "request4->CacheKey(): " << request4->CacheKey() << std::endl; + // Different input data should have different hashes + ASSERT_NE(request0->CacheKey(), request1->CacheKey()); + // Same input data should have same hashes + ASSERT_EQ(request1->CacheKey(), request2->CacheKey()); + // Two requests with same two inputs but added in different orders + ASSERT_EQ(request3->CacheKey(), request4->CacheKey()); +} + + +// Test cache size too large to initialize. 
+TEST_F(RequestResponseCacheTest, TestCacheSizeTooLarge)
+{
+  // Pick intentionally large cache size, expecting failure
+  constexpr uint64_t cache_size = ULLONG_MAX;
+  std::cout << "Create cache of size: " << cache_size << std::endl;
+  std::unique_ptr<tc::RequestResponseCache> cache;
+  auto status = tc::RequestResponseCache::Create(cache_size, &cache);
+  ASSERT_FALSE(status.IsOk()) << "Creating cache of size " << cache_size
+                              << " succeeded when it should fail.";
+}
+
+// Test cache size too small to initialize.
+// See following boost code for reference:
+// -
+// https://github.com/boostorg/interprocess/blob/41018201d6b7a34f38a0303a1ad591d978989cb8/include/boost/interprocess/managed_external_buffer.hpp#L75-L77
+// -
+// https://github.com/boostorg/interprocess/blob/41018201d6b7a34f38a0303a1ad591d978989cb8/include/boost/interprocess/detail/managed_memory_impl.hpp#L172-L174
+TEST_F(RequestResponseCacheTest, TestCacheSizeTooSmall)
+{
+  // Pick intentionally small cache size, expecting failure
+  constexpr uint64_t cache_size = 1;
+  std::cout << "Create cache of size: " << cache_size << std::endl;
+  std::unique_ptr<tc::RequestResponseCache> cache;
+  auto status = tc::RequestResponseCache::Create(cache_size, &cache);
+  ASSERT_FALSE(status.IsOk()) << "Creating cache of size " << cache_size
+                              << " succeeded when it should fail.";
+}
+
+// Test cache too small for entry
+TEST_F(RequestResponseCacheTest, TestCacheSizeSmallerThanEntry)
+{
+  // Create cache
+  constexpr uint64_t cache_size = 1024;
+  std::cout << "Create cache of size: " << cache_size << std::endl;
+  std::unique_ptr<tc::RequestResponseCache> cache;
+  tc::RequestResponseCache::Create(cache_size, &cache);
+
+  // Set output data to be larger than cache size
+  // NOTE: This is not 1 byte larger than cache_size; the cache_size + 1 is to
+  // be clear it will always be larger than the cache even if the dtype is
+  // changed.
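+  // With TYPE_INT32 elements this is roughly 4 * (cache_size + 1) bytes of
+  // output data on typical platforms, comfortably above the 1024-byte cache.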
+ std::vector large_data(cache_size + 1, 0); + std::cout << "Create large_response (larger than cache) of size: " + << large_data.size() << std::endl; + std::vector large_outputs{Tensor{"output", large_data}}; + auto large_response = GenerateResponse( + request0, dtype, memory_type, memory_type_id, large_outputs); + + std::cout << "Insert large_response into cache" << std::endl; + auto status = cache->Insert(*large_response, request0); + // We expect insertion to fail here since cache is too small + std::cout << status.Message() << std::endl; + ASSERT_FALSE(status.IsOk()) + << "Inserting item larger than cache succeeded when it should fail"; +} + +// Test hashing for consistency on same request +TEST_F(RequestResponseCacheTest, TestEviction) +{ + // Create cache + std::cout << "Create cache" << std::endl; + uint64_t cache_size = 1024; + std::unique_ptr cache; + tc::RequestResponseCache::Create(cache_size, &cache); + cache_stats(cache); + + std::cout << "Lookup unique_requests[0] in empty cache" << std::endl; + auto status = cache->Lookup(nullptr, unique_requests[0]); + // This hash not in cache yet + ASSERT_FALSE(status.IsOk()) + << "hash [" + std::to_string(unique_requests[0]->CacheKey()) + + "] should not be in cache"; + std::cout << "Insert response into cache" << std::endl; + check_status(cache->Insert(*response_400bytes, unique_requests[0])); + cache_stats(cache); + ASSERT_EQ(cache->NumEntries(), 1u); + ASSERT_EQ(cache->NumEvictions(), 0u); + + check_status(cache->Insert(*response_400bytes, unique_requests[1])); + cache_stats(cache); + ASSERT_EQ(cache->NumEntries(), 2u); + ASSERT_EQ(cache->NumEvictions(), 0u); + + check_status(cache->Insert(*response_400bytes, unique_requests[2])); + cache_stats(cache); + ASSERT_EQ(cache->NumEntries(), 2u); + ASSERT_EQ(cache->NumEvictions(), 1u); + + check_status(cache->Insert(*response_400bytes, unique_requests[3])); + cache_stats(cache); + ASSERT_EQ(cache->NumEntries(), 2u); + ASSERT_EQ(cache->NumEvictions(), 2u); +} + +// Test inserting into cache with multiple threads in parallel +// and asserting that the correct number of entries and evictions +// occurred based on cache and entry sizes +TEST_F(RequestResponseCacheTest, TestParallelInsertion) +{ + // Create cache + std::cout << "Create cache" << std::endl; + uint64_t cache_size = 1024; + std::unique_ptr cache; + tc::RequestResponseCache::Create(cache_size, &cache); + cache_stats(cache); + + // Create threads + std::vector threads; + std::cout << "Insert responses into cache with [" << thread_count + << "] threads in parallel" << std::endl; + for (size_t idx = 0; idx < thread_count; idx++) { + threads.emplace_back(std::thread( + &tc::RequestResponseCache::Insert, cache.get(), + std::ref(*response_400bytes), unique_requests[idx])); + } + + // Join threads + for (size_t idx = 0; idx < thread_count; idx++) { + std::cout << "Joining idx: " << idx << std::endl; + threads[idx].join(); + } + + // Cache size only has room for 2 entries of 100 ints, so we expect 2 entries + // and N-2 evictions for N threads + cache_stats(cache); + ASSERT_EQ(cache->NumEntries(), 2u) << "NumEntries: " << cache->NumEntries(); + ASSERT_EQ(cache->NumEvictions(), (uint64_t)(thread_count - 2u)) + << "NumEvictions: " << cache->NumEvictions(); +} + +// Test evicting from cache with multiple threads in parallel +// and asserting that the correct number of entries and evictions +// occurred +TEST_F(RequestResponseCacheTest, TestParallelEviction) +{ + // Create cache + std::cout << "Create cache" << std::endl; + uint64_t cache_size 
= 1024; + std::unique_ptr cache; + tc::RequestResponseCache::Create(cache_size, &cache); + cache_stats(cache); + + // Create threads + std::vector threads; + + // Insert [thread_count] entries into cache sequentially + for (size_t idx = 0; idx < thread_count; idx++) { + cache->Insert(*response0, unique_requests[idx]); + } + + // Assert all entries were put into cache and no evictions occurred yet + cache_stats(cache); + ASSERT_EQ(cache->NumEntries(), (uint64_t)thread_count) + << "NumEntries: " << cache->NumEntries(); + ASSERT_EQ(cache->NumEvictions(), 0u) + << "NumEvictions: " << cache->NumEvictions(); + + // Evict [thread_count] entries from cache in parallel + std::cout << "Evict from cache with [" << thread_count + << "] threads in parallel" << std::endl; + for (size_t idx = 0; idx < thread_count; idx++) { + threads.emplace_back( + std::thread(&tc::RequestResponseCache::Evict, cache.get())); + } + + // Join threads + for (size_t idx = 0; idx < thread_count; idx++) { + threads[idx].join(); + } + + // Assert all entries were evicted from cache and exactly [thread_count] + // evictions occurred + cache_stats(cache); + ASSERT_EQ(cache->NumEntries(), 0u) << "NumEntries: " << cache->NumEntries(); + ASSERT_EQ(cache->NumEvictions(), (uint64_t)thread_count) + << "NumEvictions: " << cache->NumEvictions(); +} + +// Test LRU ordering of cache +TEST_F(RequestResponseCacheTest, TestLRU) +{ + // Create cache + std::cout << "Create cache" << std::endl; + uint64_t cache_size = 1024; + std::unique_ptr cache; + tc::RequestResponseCache::Create(cache_size, &cache); + cache_stats(cache); + + // Insert 3 items into cache: 0, 1, 2 + check_status(cache->Insert(*response0, unique_requests[0])); + check_status(cache->Insert(*response0, unique_requests[1])); + check_status(cache->Insert(*response0, unique_requests[2])); + + // Verify items 0, 1, 2, in cache + reset_response(&response0, unique_requests[0]); + check_status(cache->Lookup(response0.get(), unique_requests[0])); + reset_response(&response0, unique_requests[1]); + check_status(cache->Lookup(response0.get(), unique_requests[1])); + reset_response(&response0, unique_requests[2]); + check_status(cache->Lookup(response0.get(), unique_requests[2])); + + // Evict item from cache, should be item 0 since it was looked up last + cache->Evict(); + // Assert Lookup for item 0 fails but items 1, 2 succeed + tc::Status status; + reset_response(&response0, unique_requests[0]); + status = cache->Lookup(response0.get(), unique_requests[0]); + ASSERT_FALSE(status.IsOk()); + reset_response(&response0, unique_requests[1]); + check_status(cache->Lookup(response0.get(), unique_requests[1])); + reset_response(&response0, unique_requests[2]); + check_status(cache->Lookup(response0.get(), unique_requests[2])); + + // Insert item 3, 4 + check_status(cache->Insert(*response0, unique_requests[3])); + check_status(cache->Insert(*response0, unique_requests[4])); + + // Evict twice, assert items 1 and 2 were evicted + cache->Evict(); + cache->Evict(); + reset_response(&response0, unique_requests[1]); + status = cache->Lookup(response0.get(), unique_requests[1]); + ASSERT_FALSE(status.IsOk()); + reset_response(&response0, unique_requests[2]); + status = cache->Lookup(response0.get(), unique_requests[2]); + ASSERT_FALSE(status.IsOk()); + + // Lookup items 3 and 4 + reset_response(&response0, unique_requests[3]); + check_status(cache->Lookup(response0.get(), unique_requests[3])); + reset_response(&response0, unique_requests[4]); + check_status(cache->Lookup(response0.get(), 
unique_requests[4])); + + // Evict, assert item 3 was evicted + cache->Evict(); + reset_response(&response0, unique_requests[3]); + status = cache->Lookup(response0.get(), unique_requests[3]); + ASSERT_FALSE(status.IsOk()); + reset_response(&response0, unique_requests[4]); + check_status(cache->Lookup(response0.get(), unique_requests[4])); +} + +// Test looking up from cache with multiple threads in parallel +// and asserting the responses were populated correctly +TEST_F(RequestResponseCacheTest, TestParallelLookup) +{ + // Create cache + std::cout << "Create cache" << std::endl; + uint64_t cache_size = 1024; + std::unique_ptr cache; + tc::RequestResponseCache::Create(cache_size, &cache); + cache_stats(cache); + + // Create threads + std::vector threads; + std::vector> responses; + + // Insert [thread_count] entries into cache sequentially + for (size_t idx = 0; idx < thread_count; idx++) { + // Create response for each thread to fill from cache + std::unique_ptr response; + check_status( + unique_requests[idx]->ResponseFactory()->CreateResponse(&response)); + responses.push_back(std::move(response)); + // Insert response for each thread + cache->Insert(*response0, unique_requests[idx]); + } + + // Assert all entries were put into cache and no evictions occurred yet + cache_stats(cache); + ASSERT_EQ(cache->NumEntries(), (uint64_t)thread_count) + << "NumEntries: " << cache->NumEntries(); + ASSERT_EQ(cache->NumEvictions(), 0u) + << "NumEvictions: " << cache->NumEvictions(); + + // Lookup [thread_count] entries from cache in parallel + std::cout << "Lookup from cache with [" << thread_count + << "] threads in parallel" << std::endl; + for (size_t idx = 0; idx < thread_count; idx++) { + threads.emplace_back(std::thread( + &tc::RequestResponseCache::Lookup, cache.get(), responses[idx].get(), + unique_requests[idx])); + } + + // Join threads + for (size_t idx = 0; idx < thread_count; idx++) { + threads[idx].join(); + } + + // Grab output from sample response for comparison + const auto& response0_output = response0->Outputs()[0]; + + // Verify output results from cache + for (size_t idx = 0; idx < thread_count; idx++) { + // Fetch output buffer details + const void* response_buffer = nullptr; + size_t response_byte_size = 0; + TRITONSERVER_MemoryType response_memory_type; + int64_t response_memory_type_id; + void* userp; + + // TODO: Handle multiple outputs more generically + const auto& response_test = responses[idx]; + for (const auto& response_test_output : response_test->Outputs()) { + ASSERT_EQ(response_test_output.Name(), response0_output.Name()); + ASSERT_EQ(response_test_output.DType(), response0_output.DType()); + ASSERT_EQ(response_test_output.Shape(), response0_output.Shape()); + check_status(response_test_output.DataBuffer( + &response_buffer, &response_byte_size, &response_memory_type, + &response_memory_type_id, &userp)); + + // TODO: Use Triton DType to cast buffer and compare outputs generically + int* cache_output = (int*)response_buffer; + std::cout << "Check output buffer data from cache entry for thread [" + << idx << "]:" << std::endl; + for (size_t i = 0; i < response_byte_size / sizeof(int); i++) { + std::cout << cache_output[i] << " == " << data0[i] << std::endl; + ASSERT_EQ(cache_output[i], data0[i]); + } + } + } +} + +// Test end-to-end flow of cache +TEST_F(RequestResponseCacheTest, TestEndToEnd) +{ + // Create cache + std::cout << "Create cache" << std::endl; + uint64_t cache_size = 256; + std::unique_ptr cache; + tc::RequestResponseCache::Create(cache_size, &cache); 
+  cache_stats(cache);
+
+  std::cout << "Lookup request0 in empty cache" << std::endl;
+  auto status = cache->Lookup(nullptr, request0);
+  // This hash is not in the cache yet
+  ASSERT_FALSE(status.IsOk()) << "hash [" +
+                                     std::to_string(request0->CacheKey()) +
+                                     "] should not be in cache";
+  std::cout << "Insert response into cache with request0" << std::endl;
+  // Insertion should succeed
+  check_status(cache->Insert(*response0, request0));
+  cache_stats(cache);
+
+  // Check cache stats
+  auto total_lookup_latency = cache->TotalLookupLatencyNs();
+  auto total_insertion_latency = cache->TotalInsertionLatencyNs();
+  std::cout << "Total lookup latency: " << total_lookup_latency << std::endl;
+  std::cout << "Total insertion latency: " << total_insertion_latency
+            << std::endl;
+  ASSERT_TRUE(total_lookup_latency > 0)
+      << "ERROR: Total lookup latency should be non-zero";
+  ASSERT_TRUE(total_insertion_latency > 0)
+      << "ERROR: Total insertion latency should be non-zero";
+
+  // Duplicate insertion should fail since request0 already exists in cache
+  status = cache->Insert(*response0, request0);
+  ASSERT_FALSE(status.IsOk())
+      << "Inserting duplicate item in cache should fail";
+
+  // Create response to test cache lookup
+  std::cout << "Create response object to fill from cache" << std::endl;
+  std::unique_ptr<tc::InferenceResponse> response_test;
+  check_status(request0->ResponseFactory()->CreateResponse(&response_test));
+
+  // Lookup should now succeed
+  std::cout << "Lookup request0 in cache after insertion" << std::endl;
+  check_status(cache->Lookup(response_test.get(), request0));
+
+  // Check cache stats again
+  auto total_lookup_latency2 = cache->TotalLookupLatencyNs();
+  auto total_insertion_latency2 = cache->TotalInsertionLatencyNs();
+  std::cout << "Total lookup latency2: " << total_lookup_latency2 << std::endl;
+  std::cout << "Total insertion latency2: " << total_insertion_latency2
+            << std::endl;
+  ASSERT_TRUE(total_lookup_latency2 > total_lookup_latency)
+      << "ERROR: Total lookup latency should increase";
+  ASSERT_TRUE(total_insertion_latency2 > total_insertion_latency)
+      << "ERROR: Total insertion latency should increase";
+
+  // Grab output from sample response for comparison
+  const auto& response0_output = response0->Outputs()[0];
+
+  // Fetch output buffer details
+  const void* response_buffer = nullptr;
+  size_t response_byte_size = 0;
+  TRITONSERVER_MemoryType response_memory_type;
+  int64_t response_memory_type_id;
+  void* userp;
+  // TODO: How to handle different memory types? GPU vs CPU vs Pinned, etc.
+ // TODO: Handle multiple outputs more generically + for (const auto& response_test_output : response_test->Outputs()) { + ASSERT_EQ(response_test_output.Name(), response0_output.Name()); + ASSERT_EQ(response_test_output.DType(), response0_output.DType()); + ASSERT_EQ(response_test_output.Shape(), response0_output.Shape()); + check_status(response_test_output.DataBuffer( + &response_buffer, &response_byte_size, &response_memory_type, + &response_memory_type_id, &userp)); + } + + // TODO: Use Triton DType to cast buffer and compare outputs generically + int* cache_output = (int*)response_buffer; + std::cout << "Check output buffer data from cache entry:" << std::endl; + for (size_t i = 0; i < response_byte_size / sizeof(int); i++) { + std::cout << cache_output[i] << " == " << outputs0[0].data[i] << std::endl; + ASSERT_EQ(cache_output[i], outputs0[0].data[i]); + } + + // Simple Evict() test + ASSERT_EQ(cache->NumEntries(), 1u); + ASSERT_EQ(cache->NumEvictions(), 0u); + cache->Evict(); + ASSERT_EQ(cache->NumEntries(), 0u); + ASSERT_EQ(cache->NumEvictions(), 1u); + std::cout << "Done!" << std::endl; +} + +} // namespace + +int +main(int argc, char** argv) +{ +#ifdef TRITON_ENABLE_LOGGING + LOG_SET_VERBOSE(1); +#endif // TRITON_ENABLE_LOGGING + + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/3rdparty/core-r22.12/src/tritonserver.cc b/3rdparty/core-r22.12/src/tritonserver.cc new file mode 100644 index 0000000000000000000000000000000000000000..bdce0346525815f40233482f881e0ba9294d9e3e --- /dev/null +++ b/3rdparty/core-r22.12/src/tritonserver.cc @@ -0,0 +1,3066 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
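+
+// A minimal usage sketch of the C API implemented in this file, assuming an
+// embedding application; "/models" is a placeholder repository path and error
+// handling is trimmed. Only entry points defined below are used:
+//
+//   uint32_t major = 0, minor = 0;
+//   TRITONSERVER_ApiVersion(&major, &minor);
+//
+//   TRITONSERVER_ServerOptions* options = nullptr;
+//   TRITONSERVER_Error* err = TRITONSERVER_ServerOptionsNew(&options);
+//   if (err != nullptr) {
+//     fprintf(stderr, "%s\n", TRITONSERVER_ErrorMessage(err));
+//     TRITONSERVER_ErrorDelete(err);
+//   } else {
+//     TRITONSERVER_ServerOptionsSetModelRepositoryPath(options, "/models");
+//     TRITONSERVER_ServerOptionsSetLogVerbose(options, 1);
+//     // ... hand 'options' to server creation, then ...
+//     TRITONSERVER_ServerOptionsDelete(options);
+//   }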
+ +#include +#include +#include "buffer_attributes.h" +#include "cuda_utils.h" +#include "infer_parameter.h" +#include "infer_request.h" +#include "infer_response.h" +#include "infer_stats.h" +#include "metric_family.h" +#include "metrics.h" +#include "model.h" +#include "model_config_utils.h" +#include "model_repository_manager.h" +#include "rate_limiter.h" +#include "response_allocator.h" +#include "server.h" +#include "server_message.h" +#include "status.h" +#include "triton/common/logging.h" +#include "triton/common/model_config.h" +#include "triton/common/nvtx.h" +#include "triton/common/table_printer.h" +#include "triton/common/triton_json.h" +#include "tritonserver_apis.h" + +// For unknown reason, windows will not export some functions declared +// with dllexport in tritonrepoagent.h and tritonbackend.h. To get +// those functions exported it is (also?) necessary to mark the +// definitions in this file with dllexport as well. The TRITONSERVER_* +// functions are getting exported but for consistency adding the +// declspec to these definitions as well. +#if defined(_MSC_VER) +#define TRITONAPI_DECLSPEC __declspec(dllexport) +#elif defined(__GNUC__) +#define TRITONAPI_DECLSPEC __attribute__((__visibility__("default"))) +#else +#define TRITONAPI_DECLSPEC +#endif + +namespace tc = triton::core; + +namespace { + +std::string +ResourceString(const std::string& name, const int count, const int device_id) +{ + return std::string( + "{\"name\":\"" + name + "\", \"count\":" + std::to_string(count) + + " \"device\":" + std::to_string(device_id) + "}"); +} + +std::string +RateLimitModeToString(const tc::RateLimitMode rate_limit_mode) +{ + std::string rl_mode_str(""); + switch (rate_limit_mode) { + case tc::RateLimitMode::RL_EXEC_COUNT: { + rl_mode_str = "EXEC_COUNT"; + break; + } + case tc::RateLimitMode::RL_OFF: { + rl_mode_str = "OFF"; + break; + } + } + return rl_mode_str; +} + +// +// TritonServerError +// +// Implementation for TRITONSERVER_Error. 
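+// A TRITONSERVER_Error* handed out through the C API is an opaque handle to
+// one of these objects: Create() allocates a TritonServerError and returns it
+// as a reinterpret_cast'ed TRITONSERVER_Error*, and the C entry points cast
+// it back to read the code and message. A minimal sketch of that round trip:
+//
+//   TRITONSERVER_Error* err =
+//       TritonServerError::Create(TRITONSERVER_ERROR_INTERNAL, "oops");
+//   auto* impl = reinterpret_cast<TritonServerError*>(err);  // same object
+//   // impl->Code() == TRITONSERVER_ERROR_INTERNAL, impl->Message() == "oops"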
+// +class TritonServerError { + public: + static TRITONSERVER_Error* Create( + TRITONSERVER_Error_Code code, const char* msg); + static TRITONSERVER_Error* Create( + TRITONSERVER_Error_Code code, const std::string& msg); + static TRITONSERVER_Error* Create(const tc::Status& status); + + TRITONSERVER_Error_Code Code() const { return code_; } + const std::string& Message() const { return msg_; } + + private: + TritonServerError(TRITONSERVER_Error_Code code, const std::string& msg) + : code_(code), msg_(msg) + { + } + TritonServerError(TRITONSERVER_Error_Code code, const char* msg) + : code_(code), msg_(msg) + { + } + + TRITONSERVER_Error_Code code_; + const std::string msg_; +}; + +TRITONSERVER_Error* +TritonServerError::Create(TRITONSERVER_Error_Code code, const char* msg) +{ + return reinterpret_cast( + new TritonServerError(code, msg)); +} + +TRITONSERVER_Error* +TritonServerError::Create(TRITONSERVER_Error_Code code, const std::string& msg) +{ + return reinterpret_cast( + new TritonServerError(code, msg)); +} + +TRITONSERVER_Error* +TritonServerError::Create(const tc::Status& status) +{ + // If 'status' is success then return nullptr as that indicates + // success + if (status.IsOk()) { + return nullptr; + } + + return Create( + tc::StatusCodeToTritonCode(status.StatusCode()), status.Message()); +} + +#define RETURN_IF_STATUS_ERROR(S) \ + do { \ + const tc::Status& status__ = (S); \ + if (!status__.IsOk()) { \ + return TritonServerError::Create(status__); \ + } \ + } while (false) + +// +// TritonServerMetrics +// +// Implementation for TRITONSERVER_Metrics. +// +class TritonServerMetrics { + public: + TritonServerMetrics() = default; + TRITONSERVER_Error* Serialize(const char** base, size_t* byte_size); + + private: + std::string serialized_; +}; + +TRITONSERVER_Error* +TritonServerMetrics::Serialize(const char** base, size_t* byte_size) +{ +#ifdef TRITON_ENABLE_METRICS + serialized_ = tc::Metrics::SerializedMetrics(); + *base = serialized_.c_str(); + *byte_size = serialized_.size(); + return nullptr; // Success +#else + *base = nullptr; + *byte_size = 0; + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +// +// TritonServerOptions +// +// Implementation for TRITONSERVER_ServerOptions. 
+// +class TritonServerOptions { + public: + TritonServerOptions(); + + const std::string& ServerId() const { return server_id_; } + void SetServerId(const char* id) { server_id_ = id; } + + const std::set& ModelRepositoryPaths() const + { + return repo_paths_; + } + void SetModelRepositoryPath(const char* p) { repo_paths_.insert(p); } + + tc::ModelControlMode ModelControlMode() const { return model_control_mode_; } + void SetModelControlMode(tc::ModelControlMode m) { model_control_mode_ = m; } + + const std::set& StartupModels() const { return models_; } + void SetStartupModel(const char* m) { models_.insert(m); } + + bool ExitOnError() const { return exit_on_error_; } + void SetExitOnError(bool b) { exit_on_error_ = b; } + + bool StrictModelConfig() const { return strict_model_config_; } + void SetStrictModelConfig(bool b) { strict_model_config_ = b; } + + tc::RateLimitMode RateLimiterMode() const { return rate_limit_mode_; } + void SetRateLimiterMode(tc::RateLimitMode m) { rate_limit_mode_ = m; } + + TRITONSERVER_Error* AddRateLimiterResource( + const std::string& resource, const size_t count, const int device); + + // The resource map is the map from device id to the map of + // of resources with their respective counts for that device. + const tc::RateLimiter::ResourceMap& RateLimiterResources() const + { + return rate_limit_resource_map_; + } + + uint64_t PinnedMemoryPoolByteSize() const { return pinned_memory_pool_size_; } + void SetPinnedMemoryPoolByteSize(uint64_t s) { pinned_memory_pool_size_ = s; } + + uint64_t ResponseCacheByteSize() const { return response_cache_byte_size_; } + void SetResponseCacheByteSize(uint64_t s) { response_cache_byte_size_ = s; } + + const std::map& CudaMemoryPoolByteSize() const + { + return cuda_memory_pool_size_; + } + void SetCudaMemoryPoolByteSize(int id, uint64_t s) + { + cuda_memory_pool_size_[id] = s; + } + + double MinSupportedComputeCapability() const + { + return min_compute_capability_; + } + void SetMinSupportedComputeCapability(double c) + { + min_compute_capability_ = c; + } + + bool StrictReadiness() const { return strict_readiness_; } + void SetStrictReadiness(bool b) { strict_readiness_ = b; } + + unsigned int ExitTimeout() const { return exit_timeout_; } + void SetExitTimeout(unsigned int t) { exit_timeout_ = t; } + + unsigned int BufferManagerThreadCount() const + { + return buffer_manager_thread_count_; + } + void SetBufferManagerThreadCount(unsigned int c) + { + buffer_manager_thread_count_ = c; + } + + unsigned int ModelLoadThreadCount() const { return model_load_thread_count_; } + void SetModelLoadThreadCount(unsigned int c) { model_load_thread_count_ = c; } + + bool Metrics() const { return metrics_; } + void SetMetrics(bool b) { metrics_ = b; } + + bool GpuMetrics() const { return gpu_metrics_; } + void SetGpuMetrics(bool b) { gpu_metrics_ = b; } + + bool CpuMetrics() const { return cpu_metrics_; } + void SetCpuMetrics(bool b) { cpu_metrics_ = b; } + + uint64_t MetricsInterval() const { return metrics_interval_; } + void SetMetricsInterval(uint64_t m) { metrics_interval_ = m; } + + const std::string& BackendDir() const { return backend_dir_; } + void SetBackendDir(const std::string& bd) { backend_dir_ = bd; } + + const std::string& RepoAgentDir() const { return repoagent_dir_; } + void SetRepoAgentDir(const std::string& rad) { repoagent_dir_ = rad; } + + // The backend config map is a map from backend name to the + // setting=value pairs for that backend. 
The empty backend name ("") + // is used to communicate configuration information that is used + // internally. + const triton::common::BackendCmdlineConfigMap& BackendCmdlineConfigMap() const + { + return backend_cmdline_config_map_; + } + TRITONSERVER_Error* AddBackendConfig( + const std::string& backend_name, const std::string& setting, + const std::string& value); + + TRITONSERVER_Error* SetHostPolicy( + const std::string& policy_name, const std::string& setting, + const std::string& value); + const triton::common::HostPolicyCmdlineConfigMap& HostPolicyCmdlineConfigMap() + const + { + return host_policy_map_; + } + + private: + std::string server_id_; + std::set repo_paths_; + tc::ModelControlMode model_control_mode_; + std::set models_; + bool exit_on_error_; + bool strict_model_config_; + bool strict_readiness_; + tc::RateLimitMode rate_limit_mode_; + tc::RateLimiter::ResourceMap rate_limit_resource_map_; + bool metrics_; + bool gpu_metrics_; + bool cpu_metrics_; + uint64_t metrics_interval_; + unsigned int exit_timeout_; + uint64_t pinned_memory_pool_size_; + uint64_t response_cache_byte_size_; + unsigned int buffer_manager_thread_count_; + unsigned int model_load_thread_count_; + std::map cuda_memory_pool_size_; + double min_compute_capability_; + std::string backend_dir_; + std::string repoagent_dir_; + triton::common::BackendCmdlineConfigMap backend_cmdline_config_map_; + triton::common::HostPolicyCmdlineConfigMap host_policy_map_; +}; + +TritonServerOptions::TritonServerOptions() + : server_id_("triton"), + model_control_mode_(tc::ModelControlMode::MODE_POLL), + exit_on_error_(true), strict_model_config_(true), strict_readiness_(true), + rate_limit_mode_(tc::RateLimitMode::RL_OFF), metrics_(true), + gpu_metrics_(true), cpu_metrics_(true), metrics_interval_(2000), + exit_timeout_(30), pinned_memory_pool_size_(1 << 28), + response_cache_byte_size_(0), buffer_manager_thread_count_(0), + model_load_thread_count_( + std::max(2u, 2 * std::thread::hardware_concurrency())), +#ifdef TRITON_ENABLE_GPU + min_compute_capability_(TRITON_MIN_COMPUTE_CAPABILITY), +#else + min_compute_capability_(0), +#endif // TRITON_ENABLE_GPU + backend_dir_("/opt/tritonserver/backends"), + repoagent_dir_("/opt/tritonserver/repoagents") +{ +#ifndef TRITON_ENABLE_METRICS + metrics_ = false; + gpu_metrics_ = false; + cpu_metrics_ = false; +#endif // TRITON_ENABLE_METRICS + +#ifndef TRITON_ENABLE_METRICS_GPU + gpu_metrics_ = false; +#endif // TRITON_ENABLE_METRICS_GPU + +#ifndef TRITON_ENABLE_METRICS_CPU + cpu_metrics_ = false; +#endif // TRITON_ENABLE_METRICS_CPU +} + +TRITONSERVER_Error* +TritonServerOptions::AddRateLimiterResource( + const std::string& name, const size_t count, const int device) +{ + auto ditr = rate_limit_resource_map_.find(device); + if (ditr == rate_limit_resource_map_.end()) { + ditr = rate_limit_resource_map_ + .emplace(device, std::map()) + .first; + } + auto ritr = ditr->second.find(name); + if (ritr == ditr->second.end()) { + ditr->second.emplace(name, count).first; + } else { + // If already present then store the minimum of the two. 
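+    // For example (hypothetical resource name "R1"): registering "R1" with
+    // count 10 and then again with count 4 on the same device leaves the
+    // stored count at 4, the smaller of the two.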
+ if (ritr->second > count) { + ritr->second = count; + } + } + + return nullptr; // success +} + +TRITONSERVER_Error* +TritonServerOptions::AddBackendConfig( + const std::string& backend_name, const std::string& setting, + const std::string& value) +{ + triton::common::BackendCmdlineConfig& cc = + backend_cmdline_config_map_[backend_name]; + cc.push_back(std::make_pair(setting, value)); + + return nullptr; // success +} + +TRITONSERVER_Error* +TritonServerOptions::SetHostPolicy( + const std::string& policy_name, const std::string& setting, + const std::string& value) +{ + // Check if supported setting is passed + if ((setting != "numa-node") && (setting != "cpu-cores")) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + std::string( + "Unsupported host policy setting '" + setting + + "' is specified, supported settings are 'numa-node', 'cpu-cores'") + .c_str()); + } + + triton::common::HostPolicyCmdlineConfig& hp = host_policy_map_[policy_name]; + hp[setting] = value; + + return nullptr; // success +} + +#define SetDurationStat(DOC, PARENT, STAT_NAME, COUNT, NS) \ + do { \ + triton::common::TritonJson::Value dstat( \ + DOC, triton::common::TritonJson::ValueType::OBJECT); \ + dstat.AddUInt("count", (COUNT)); \ + dstat.AddUInt("ns", (NS)); \ + PARENT.Add(STAT_NAME, std::move(dstat)); \ + } while (false) + +} // namespace + +extern "C" { + +// +// TRITONSERVER API Version +// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ApiVersion(uint32_t* major, uint32_t* minor) +{ + *major = TRITONSERVER_API_VERSION_MAJOR; + *minor = TRITONSERVER_API_VERSION_MINOR; + return nullptr; // success +} + +// +// TRITONSERVER_DataType +// +TRITONAPI_DECLSPEC const char* +TRITONSERVER_DataTypeString(TRITONSERVER_DataType datatype) +{ + switch (datatype) { + case TRITONSERVER_TYPE_BOOL: + return "BOOL"; + case TRITONSERVER_TYPE_UINT8: + return "UINT8"; + case TRITONSERVER_TYPE_UINT16: + return "UINT16"; + case TRITONSERVER_TYPE_UINT32: + return "UINT32"; + case TRITONSERVER_TYPE_UINT64: + return "UINT64"; + case TRITONSERVER_TYPE_INT8: + return "INT8"; + case TRITONSERVER_TYPE_INT16: + return "INT16"; + case TRITONSERVER_TYPE_INT32: + return "INT32"; + case TRITONSERVER_TYPE_INT64: + return "INT64"; + case TRITONSERVER_TYPE_FP16: + return "FP16"; + case TRITONSERVER_TYPE_FP32: + return "FP32"; + case TRITONSERVER_TYPE_FP64: + return "FP64"; + case TRITONSERVER_TYPE_BYTES: + return "BYTES"; + case TRITONSERVER_TYPE_BF16: + return "BF16"; + default: + break; + } + + return ""; +} + +TRITONAPI_DECLSPEC TRITONSERVER_DataType +TRITONSERVER_StringToDataType(const char* dtype) +{ + const size_t len = strlen(dtype); + return tc::DataTypeToTriton( + triton::common::ProtocolStringToDataType(dtype, len)); +} + +TRITONAPI_DECLSPEC uint32_t +TRITONSERVER_DataTypeByteSize(TRITONSERVER_DataType datatype) +{ + switch (datatype) { + case TRITONSERVER_TYPE_BOOL: + case TRITONSERVER_TYPE_INT8: + case TRITONSERVER_TYPE_UINT8: + return 1; + case TRITONSERVER_TYPE_INT16: + case TRITONSERVER_TYPE_UINT16: + case TRITONSERVER_TYPE_FP16: + case TRITONSERVER_TYPE_BF16: + return 2; + case TRITONSERVER_TYPE_INT32: + case TRITONSERVER_TYPE_UINT32: + case TRITONSERVER_TYPE_FP32: + return 4; + case TRITONSERVER_TYPE_INT64: + case TRITONSERVER_TYPE_UINT64: + case TRITONSERVER_TYPE_FP64: + return 8; + case TRITONSERVER_TYPE_BYTES: + return 0; + default: + break; + } + + return 0; +} + +// +// TRITONSERVER_MemoryType +// +TRITONAPI_DECLSPEC const char* +TRITONSERVER_MemoryTypeString(TRITONSERVER_MemoryType memtype) +{ + 
switch (memtype) { + case TRITONSERVER_MEMORY_CPU: + return "CPU"; + case TRITONSERVER_MEMORY_CPU_PINNED: + return "CPU_PINNED"; + case TRITONSERVER_MEMORY_GPU: + return "GPU"; + default: + break; + } + + return ""; +} + +// +// TRITONSERVER_Parameter +// +TRITONAPI_DECLSPEC const char* +TRITONSERVER_ParameterTypeString(TRITONSERVER_ParameterType paramtype) +{ + switch (paramtype) { + case TRITONSERVER_PARAMETER_STRING: + return "STRING"; + case TRITONSERVER_PARAMETER_INT: + return "INT"; + case TRITONSERVER_PARAMETER_BOOL: + return "BOOL"; + default: + break; + } + + return ""; +} + +TRITONAPI_DECLSPEC TRITONSERVER_Parameter* +TRITONSERVER_ParameterNew( + const char* name, const TRITONSERVER_ParameterType type, const void* value) +{ + std::unique_ptr lparam; + switch (type) { + case TRITONSERVER_PARAMETER_STRING: + lparam.reset(new tc::InferenceParameter( + name, reinterpret_cast(value))); + break; + case TRITONSERVER_PARAMETER_INT: + lparam.reset(new tc::InferenceParameter( + name, *reinterpret_cast(value))); + break; + case TRITONSERVER_PARAMETER_BOOL: + lparam.reset(new tc::InferenceParameter( + name, *reinterpret_cast(value))); + break; + default: + break; + } + return reinterpret_cast(lparam.release()); +} + +TRITONAPI_DECLSPEC TRITONSERVER_Parameter* +TRITONSERVER_ParameterBytesNew( + const char* name, const void* byte_ptr, const uint64_t size) +{ + std::unique_ptr lparam( + new tc::InferenceParameter(name, byte_ptr, size)); + return reinterpret_cast(lparam.release()); +} + +TRITONAPI_DECLSPEC void +TRITONSERVER_ParameterDelete(TRITONSERVER_Parameter* parameter) +{ + delete reinterpret_cast(parameter); +} + +// +// TRITONSERVER_InstanceGroupKind +// +TRITONAPI_DECLSPEC const char* +TRITONSERVER_InstanceGroupKindString(TRITONSERVER_InstanceGroupKind kind) +{ + switch (kind) { + case TRITONSERVER_INSTANCEGROUPKIND_AUTO: + return "AUTO"; + case TRITONSERVER_INSTANCEGROUPKIND_CPU: + return "CPU"; + case TRITONSERVER_INSTANCEGROUPKIND_GPU: + return "GPU"; + case TRITONSERVER_INSTANCEGROUPKIND_MODEL: + return "MODEL"; + default: + break; + } + + return ""; +} + +// +// TRITONSERVER_Log +// +TRITONAPI_DECLSPEC bool +TRITONSERVER_LogIsEnabled(TRITONSERVER_LogLevel level) +{ + switch (level) { + case TRITONSERVER_LOG_INFO: + return LOG_INFO_IS_ON; + case TRITONSERVER_LOG_WARN: + return LOG_WARNING_IS_ON; + case TRITONSERVER_LOG_ERROR: + return LOG_ERROR_IS_ON; + case TRITONSERVER_LOG_VERBOSE: + return LOG_VERBOSE_IS_ON(1); + } + + return false; +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_LogMessage( + TRITONSERVER_LogLevel level, const char* filename, const int line, + const char* msg) +{ + switch (level) { + case TRITONSERVER_LOG_INFO: + LOG_INFO_FL(filename, line) << msg; + return nullptr; + case TRITONSERVER_LOG_WARN: + LOG_WARNING_FL(filename, line) << msg; + return nullptr; + case TRITONSERVER_LOG_ERROR: + LOG_ERROR_FL(filename, line) << msg; + return nullptr; + case TRITONSERVER_LOG_VERBOSE: + LOG_VERBOSE_FL(1, filename, line) << msg; + return nullptr; + default: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string("unknown logging level '" + std::to_string(level) + "'") + .c_str()); + } +} + +// +// TRITONSERVER_Error +// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ErrorNew(TRITONSERVER_Error_Code code, const char* msg) +{ + return reinterpret_cast( + TritonServerError::Create(code, msg)); +} + +TRITONAPI_DECLSPEC void +TRITONSERVER_ErrorDelete(TRITONSERVER_Error* error) +{ + TritonServerError* lerror = reinterpret_cast(error); + 
delete lerror; +} + +TRITONSERVER_Error_Code +TRITONSERVER_ErrorCode(TRITONSERVER_Error* error) +{ + TritonServerError* lerror = reinterpret_cast(error); + return lerror->Code(); +} + +TRITONAPI_DECLSPEC const char* +TRITONSERVER_ErrorCodeString(TRITONSERVER_Error* error) +{ + TritonServerError* lerror = reinterpret_cast(error); + return tc::Status::CodeString(tc::TritonCodeToStatusCode(lerror->Code())); +} + +TRITONAPI_DECLSPEC const char* +TRITONSERVER_ErrorMessage(TRITONSERVER_Error* error) +{ + TritonServerError* lerror = reinterpret_cast(error); + return lerror->Message().c_str(); +} + +// +// TRITONSERVER_ResponseAllocator +// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ResponseAllocatorNew( + TRITONSERVER_ResponseAllocator** allocator, + TRITONSERVER_ResponseAllocatorAllocFn_t alloc_fn, + TRITONSERVER_ResponseAllocatorReleaseFn_t release_fn, + TRITONSERVER_ResponseAllocatorStartFn_t start_fn) +{ + *allocator = reinterpret_cast( + new tc::ResponseAllocator(alloc_fn, release_fn, start_fn)); + return nullptr; // Success +} + +TRITONSERVER_Error* +TRITONSERVER_ResponseAllocatorSetQueryFunction( + TRITONSERVER_ResponseAllocator* allocator, + TRITONSERVER_ResponseAllocatorQueryFn_t query_fn) +{ + reinterpret_cast(allocator)->SetQueryFunction( + query_fn); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ResponseAllocatorSetBufferAttributesFunction( + TRITONSERVER_ResponseAllocator* allocator, + TRITONSERVER_ResponseAllocatorBufferAttributesFn_t buffer_attributes_fn) +{ + reinterpret_cast(allocator) + ->SetBufferAttributesFunction(buffer_attributes_fn); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ResponseAllocatorDelete(TRITONSERVER_ResponseAllocator* allocator) +{ + tc::ResponseAllocator* lalloc = + reinterpret_cast(allocator); + delete lalloc; + return nullptr; // Success +} + +// +// TRITONSERVER_Message +// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_MessageNewFromSerializedJson( + TRITONSERVER_Message** message, const char* base, size_t byte_size) +{ + *message = reinterpret_cast( + new tc::TritonServerMessage({base, byte_size})); + return nullptr; +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_MessageDelete(TRITONSERVER_Message* message) +{ + tc::TritonServerMessage* lmessage = + reinterpret_cast(message); + delete lmessage; + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_MessageSerializeToJson( + TRITONSERVER_Message* message, const char** base, size_t* byte_size) +{ + tc::TritonServerMessage* lmessage = + reinterpret_cast(message); + lmessage->Serialize(base, byte_size); + return nullptr; // Success +} + +// +// TRITONSERVER_Metrics +// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_MetricsDelete(TRITONSERVER_Metrics* metrics) +{ + TritonServerMetrics* lmetrics = + reinterpret_cast(metrics); + delete lmetrics; + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_MetricsFormatted( + TRITONSERVER_Metrics* metrics, TRITONSERVER_MetricFormat format, + const char** base, size_t* byte_size) +{ + TritonServerMetrics* lmetrics = + reinterpret_cast(metrics); + + switch (format) { + case TRITONSERVER_METRIC_PROMETHEUS: { + return lmetrics->Serialize(base, byte_size); + } + + default: + break; + } + + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string("unknown metrics format '" + std::to_string(format) + "'") + .c_str()); +} + +// +// TRITONSERVER_InferenceTrace +// 
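+// Trace levels are used as bit flags here; as the *New functions below show,
+// the MIN and MAX levels are folded into TIMESTAMPS when a trace is created.
+// A minimal sketch of that mapping (assuming tracing is enabled):
+//
+//   level = TRITONSERVER_TRACE_LEVEL_MIN;
+//   // after TRITONSERVER_InferenceTraceNew(), the trace is created with
+//   // TRITONSERVER_TRACE_LEVEL_TIMESTAMPS instead.
+//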
+TRITONAPI_DECLSPEC const char* +TRITONSERVER_InferenceTraceLevelString(TRITONSERVER_InferenceTraceLevel level) +{ + switch (level) { + case TRITONSERVER_TRACE_LEVEL_DISABLED: + return "DISABLED"; + case TRITONSERVER_TRACE_LEVEL_MIN: + return "MIN"; + case TRITONSERVER_TRACE_LEVEL_MAX: + return "MAX"; + case TRITONSERVER_TRACE_LEVEL_TIMESTAMPS: + return "TIMESTAMPS"; + case TRITONSERVER_TRACE_LEVEL_TENSORS: + return "TENSORS"; + } + + return ""; +} + +TRITONAPI_DECLSPEC const char* +TRITONSERVER_InferenceTraceActivityString( + TRITONSERVER_InferenceTraceActivity activity) +{ + switch (activity) { + case TRITONSERVER_TRACE_REQUEST_START: + return "REQUEST_START"; + case TRITONSERVER_TRACE_QUEUE_START: + return "QUEUE_START"; + case TRITONSERVER_TRACE_COMPUTE_START: + return "COMPUTE_START"; + case TRITONSERVER_TRACE_COMPUTE_INPUT_END: + return "COMPUTE_INPUT_END"; + case TRITONSERVER_TRACE_COMPUTE_OUTPUT_START: + return "COMPUTE_OUTPUT_START"; + case TRITONSERVER_TRACE_COMPUTE_END: + return "COMPUTE_END"; + case TRITONSERVER_TRACE_REQUEST_END: + return "REQUEST_END"; + case TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT: + return "TENSOR_QUEUE_INPUT"; + case TRITONSERVER_TRACE_TENSOR_BACKEND_INPUT: + return "TENSOR_BACKEND_INPUT"; + case TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT: + return "TENSOR_BACKEND_OUTPUT"; + } + + return ""; +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceTraceNew( + TRITONSERVER_InferenceTrace** trace, TRITONSERVER_InferenceTraceLevel level, + uint64_t parent_id, TRITONSERVER_InferenceTraceActivityFn_t activity_fn, + TRITONSERVER_InferenceTraceReleaseFn_t release_fn, void* trace_userp) +{ +#ifdef TRITON_ENABLE_TRACING + if ((level & TRITONSERVER_TRACE_LEVEL_MIN) > 0) { + level = static_cast( + (level ^ TRITONSERVER_TRACE_LEVEL_MIN) | + TRITONSERVER_TRACE_LEVEL_TIMESTAMPS); + } + if ((level & TRITONSERVER_TRACE_LEVEL_MAX) > 0) { + level = static_cast( + (level ^ TRITONSERVER_TRACE_LEVEL_MAX) | + TRITONSERVER_TRACE_LEVEL_TIMESTAMPS); + } + tc::InferenceTrace* ltrace = new tc::InferenceTrace( + level, parent_id, activity_fn, nullptr, release_fn, trace_userp); + *trace = reinterpret_cast(ltrace); + return nullptr; // Success +#else + *trace = nullptr; + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "inference tracing not supported"); +#endif // TRITON_ENABLE_TRACING +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceTraceTensorNew( + TRITONSERVER_InferenceTrace** trace, TRITONSERVER_InferenceTraceLevel level, + uint64_t parent_id, TRITONSERVER_InferenceTraceActivityFn_t activity_fn, + TRITONSERVER_InferenceTraceTensorActivityFn_t tensor_activity_fn, + TRITONSERVER_InferenceTraceReleaseFn_t release_fn, void* trace_userp) +{ +#ifdef TRITON_ENABLE_TRACING + if ((level & TRITONSERVER_TRACE_LEVEL_MIN) > 0) { + level = static_cast( + (level ^ TRITONSERVER_TRACE_LEVEL_MIN) | + TRITONSERVER_TRACE_LEVEL_TIMESTAMPS); + } + if ((level & TRITONSERVER_TRACE_LEVEL_MAX) > 0) { + level = static_cast( + (level ^ TRITONSERVER_TRACE_LEVEL_MAX) | + TRITONSERVER_TRACE_LEVEL_TIMESTAMPS); + } + tc::InferenceTrace* ltrace = new tc::InferenceTrace( + level, parent_id, activity_fn, tensor_activity_fn, release_fn, + trace_userp); + *trace = reinterpret_cast(ltrace); + return nullptr; // Success +#else + *trace = nullptr; + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "inference tracing not supported"); +#endif // TRITON_ENABLE_TRACING +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* 
+TRITONSERVER_InferenceTraceDelete(TRITONSERVER_InferenceTrace* trace) +{ +#ifdef TRITON_ENABLE_TRACING + tc::InferenceTrace* ltrace = reinterpret_cast(trace); + delete ltrace; + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "inference tracing not supported"); +#endif // TRITON_ENABLE_TRACING +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceTraceId(TRITONSERVER_InferenceTrace* trace, uint64_t* id) +{ +#ifdef TRITON_ENABLE_TRACING + tc::InferenceTrace* ltrace = reinterpret_cast(trace); + *id = ltrace->Id(); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "inference tracing not supported"); +#endif // TRITON_ENABLE_TRACING +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceTraceParentId( + TRITONSERVER_InferenceTrace* trace, uint64_t* parent_id) +{ +#ifdef TRITON_ENABLE_TRACING + tc::InferenceTrace* ltrace = reinterpret_cast(trace); + *parent_id = ltrace->ParentId(); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "inference tracing not supported"); +#endif // TRITON_ENABLE_TRACING +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceTraceModelName( + TRITONSERVER_InferenceTrace* trace, const char** model_name) +{ +#ifdef TRITON_ENABLE_TRACING + tc::InferenceTrace* ltrace = reinterpret_cast(trace); + *model_name = ltrace->ModelName().c_str(); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "inference tracing not supported"); +#endif // TRITON_ENABLE_TRACING +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceTraceModelVersion( + TRITONSERVER_InferenceTrace* trace, int64_t* model_version) +{ +#ifdef TRITON_ENABLE_TRACING + tc::InferenceTrace* ltrace = reinterpret_cast(trace); + *model_version = ltrace->ModelVersion(); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "inference tracing not supported"); +#endif // TRITON_ENABLE_TRACING +} + +// +// TRITONSERVER_ServerOptions +// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsNew(TRITONSERVER_ServerOptions** options) +{ + *options = + reinterpret_cast(new TritonServerOptions()); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsDelete(TRITONSERVER_ServerOptions* options) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + delete loptions; + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetServerId( + TRITONSERVER_ServerOptions* options, const char* server_id) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetServerId(server_id); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetModelRepositoryPath( + TRITONSERVER_ServerOptions* options, const char* model_repository_path) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetModelRepositoryPath(model_repository_path); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetModelControlMode( + TRITONSERVER_ServerOptions* options, TRITONSERVER_ModelControlMode mode) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + + // convert mode from TRITONSERVER_ to triton::core + switch (mode) { + case TRITONSERVER_MODEL_CONTROL_NONE: { + 
loptions->SetModelControlMode(tc::ModelControlMode::MODE_NONE); + break; + } + case TRITONSERVER_MODEL_CONTROL_POLL: { + loptions->SetModelControlMode(tc::ModelControlMode::MODE_POLL); + break; + } + case TRITONSERVER_MODEL_CONTROL_EXPLICIT: { + loptions->SetModelControlMode(tc::ModelControlMode::MODE_EXPLICIT); + break; + } + default: { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string("unknown control mode '" + std::to_string(mode) + "'") + .c_str()); + } + } + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetStartupModel( + TRITONSERVER_ServerOptions* options, const char* model_name) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetStartupModel(model_name); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetExitOnError( + TRITONSERVER_ServerOptions* options, bool exit) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetExitOnError(exit); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetStrictModelConfig( + TRITONSERVER_ServerOptions* options, bool strict) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetStrictModelConfig(strict); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetRateLimiterMode( + TRITONSERVER_ServerOptions* options, TRITONSERVER_RateLimitMode mode) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + + // convert mode from TRITONSERVER_ to triton::core + switch (mode) { + case TRITONSERVER_RATE_LIMIT_EXEC_COUNT: { + loptions->SetRateLimiterMode(tc::RateLimitMode::RL_EXEC_COUNT); + break; + } + case TRITONSERVER_RATE_LIMIT_OFF: { + loptions->SetRateLimiterMode(tc::RateLimitMode::RL_OFF); + break; + } + default: { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string("unknown rate limit mode '" + std::to_string(mode) + "'") + .c_str()); + } + } + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsAddRateLimiterResource( + TRITONSERVER_ServerOptions* options, const char* name, const size_t count, + const int device) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + return loptions->AddRateLimiterResource(name, count, device); +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetPinnedMemoryPoolByteSize( + TRITONSERVER_ServerOptions* options, uint64_t size) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetPinnedMemoryPoolByteSize(size); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetCudaMemoryPoolByteSize( + TRITONSERVER_ServerOptions* options, int gpu_device, uint64_t size) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetCudaMemoryPoolByteSize(gpu_device, size); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetResponseCacheByteSize( + TRITONSERVER_ServerOptions* options, uint64_t size) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetResponseCacheByteSize(size); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetMinSupportedComputeCapability( + TRITONSERVER_ServerOptions* options, double cc) +{ + TritonServerOptions* loptions = + 
reinterpret_cast(options); + loptions->SetMinSupportedComputeCapability(cc); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetStrictReadiness( + TRITONSERVER_ServerOptions* options, bool strict) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetStrictReadiness(strict); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetExitTimeout( + TRITONSERVER_ServerOptions* options, unsigned int timeout) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetExitTimeout(timeout); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetBufferManagerThreadCount( + TRITONSERVER_ServerOptions* options, unsigned int thread_count) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetBufferManagerThreadCount(thread_count); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetModelLoadThreadCount( + TRITONSERVER_ServerOptions* options, unsigned int thread_count) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetModelLoadThreadCount(thread_count); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetLogFile( + TRITONSERVER_ServerOptions* options, const char* file) +{ +#ifdef TRITON_ENABLE_LOGGING + std::string out_file; + if (file != nullptr) { + out_file = std::string(file); + } + const std::string& error = LOG_SET_OUT_FILE(out_file); + if (!error.empty()) { + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, (error).c_str()); + } + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "logging not supported"); +#endif // TRITON_ENABLE_LOGGING +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetLogInfo( + TRITONSERVER_ServerOptions* options, bool log) +{ +#ifdef TRITON_ENABLE_LOGGING + // Logging is global for now... + LOG_ENABLE_INFO(log); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "logging not supported"); +#endif // TRITON_ENABLE_LOGGING +} + +// Enable or disable warning level logging. +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetLogWarn( + TRITONSERVER_ServerOptions* options, bool log) +{ +#ifdef TRITON_ENABLE_LOGGING + // Logging is global for now... + LOG_ENABLE_WARNING(log); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "logging not supported"); +#endif // TRITON_ENABLE_LOGGING +} + +// Enable or disable error level logging. +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetLogError( + TRITONSERVER_ServerOptions* options, bool log) +{ +#ifdef TRITON_ENABLE_LOGGING + // Logging is global for now... + LOG_ENABLE_ERROR(log); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "logging not supported"); +#endif // TRITON_ENABLE_LOGGING +} + +// Set verbose logging level. Level zero disables verbose logging. +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetLogVerbose( + TRITONSERVER_ServerOptions* options, int level) +{ +#ifdef TRITON_ENABLE_LOGGING + // Logging is global for now... 
+ LOG_SET_VERBOSE(level); +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "logging not supported"); +#endif // TRITON_ENABLE_LOGGING + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetLogFormat( + TRITONSERVER_ServerOptions* options, const TRITONSERVER_LogFormat format) +{ +#ifdef TRITON_ENABLE_LOGGING + // Logging is global for now... + switch (format) { + case TRITONSERVER_LOG_DEFAULT: + LOG_SET_FORMAT(triton::common::Logger::Format::kDEFAULT); + break; + case TRITONSERVER_LOG_ISO8601: + LOG_SET_FORMAT(triton::common::Logger::Format::kISO8601); + break; + } +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "logging not supported"); +#endif // TRITON_ENABLE_LOGGING + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetMetrics( + TRITONSERVER_ServerOptions* options, bool metrics) +{ +#ifdef TRITON_ENABLE_METRICS + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetMetrics(metrics); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetGpuMetrics( + TRITONSERVER_ServerOptions* options, bool gpu_metrics) +{ +#ifdef TRITON_ENABLE_METRICS + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetGpuMetrics(gpu_metrics); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetCpuMetrics( + TRITONSERVER_ServerOptions* options, bool cpu_metrics) +{ +#ifdef TRITON_ENABLE_METRICS + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetCpuMetrics(cpu_metrics); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetMetricsInterval( + TRITONSERVER_ServerOptions* options, uint64_t metrics_interval_ms) +{ +#ifdef TRITON_ENABLE_METRICS + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetMetricsInterval(metrics_interval_ms); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetBackendDirectory( + TRITONSERVER_ServerOptions* options, const char* backend_dir) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetBackendDir(backend_dir); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetRepoAgentDirectory( + TRITONSERVER_ServerOptions* options, const char* repoagent_dir) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + loptions->SetRepoAgentDir(repoagent_dir); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetModelLoadDeviceLimit( + TRITONSERVER_ServerOptions* options, + const TRITONSERVER_InstanceGroupKind kind, const int device_id, + const double fraction) +{ + if (device_id < 0) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("expects device ID >= 0, got ") + + 
std::to_string(device_id)) + .c_str()); + } else if ((fraction < 0.0) || (fraction > 1.0)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("expects limit fraction to be in range [0.0, 1.0], got ") + + std::to_string(fraction)) + .c_str()); + } + + TritonServerOptions* loptions = + reinterpret_cast(options); + switch (kind) { + case TRITONSERVER_INSTANCEGROUPKIND_GPU: { + static std::string key_prefix = "model-load-gpu-limit-device-"; + return loptions->AddBackendConfig( + "", key_prefix + std::to_string(device_id), std::to_string(fraction)); + } + default: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("given device kind is not supported, got: ") + + TRITONSERVER_InstanceGroupKindString(kind)) + .c_str()); + } +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetBackendConfig( + TRITONSERVER_ServerOptions* options, const char* backend_name, + const char* setting, const char* value) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + return loptions->AddBackendConfig(backend_name, setting, value); +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerOptionsSetHostPolicy( + TRITONSERVER_ServerOptions* options, const char* policy_name, + const char* setting, const char* value) +{ + TritonServerOptions* loptions = + reinterpret_cast(options); + return loptions->SetHostPolicy(policy_name, setting, value); +} + +// +// TRITONSERVER_InferenceRequest +// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestNew( + TRITONSERVER_InferenceRequest** inference_request, + TRITONSERVER_Server* server, const char* model_name, + const int64_t model_version) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + + std::shared_ptr model; + RETURN_IF_STATUS_ERROR(lserver->GetModel(model_name, model_version, &model)); + + *inference_request = reinterpret_cast( + new tc::InferenceRequest(model, model_version)); + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestDelete( + TRITONSERVER_InferenceRequest* inference_request) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + delete lrequest; + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestId( + TRITONSERVER_InferenceRequest* inference_request, const char** id) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + *id = lrequest->Id().c_str(); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetId( + TRITONSERVER_InferenceRequest* inference_request, const char* id) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + lrequest->SetId(id); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestFlags( + TRITONSERVER_InferenceRequest* inference_request, uint32_t* flags) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + *flags = lrequest->Flags(); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetFlags( + TRITONSERVER_InferenceRequest* inference_request, uint32_t flags) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + lrequest->SetFlags(flags); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestCorrelationId( + TRITONSERVER_InferenceRequest* inference_request, uint64_t* 
correlation_id) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + const tc::InferenceRequest::SequenceId& corr_id = lrequest->CorrelationId(); + if (corr_id.Type() != tc::InferenceRequest::SequenceId::DataType::UINT64) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string("given request's correlation id is not an unsigned int") + .c_str()); + } + *correlation_id = corr_id.UnsignedIntValue(); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestCorrelationIdString( + TRITONSERVER_InferenceRequest* inference_request, + const char** correlation_id) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + const tc::InferenceRequest::SequenceId& corr_id = lrequest->CorrelationId(); + if (corr_id.Type() != tc::InferenceRequest::SequenceId::DataType::STRING) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string("given request's correlation id is not a string").c_str()); + } + *correlation_id = corr_id.StringValue().c_str(); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetCorrelationId( + TRITONSERVER_InferenceRequest* inference_request, uint64_t correlation_id) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + lrequest->SetCorrelationId(tc::InferenceRequest::SequenceId(correlation_id)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetCorrelationIdString( + TRITONSERVER_InferenceRequest* inference_request, + const char* correlation_id) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + if (std::string(correlation_id).length() > 128) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + std::string( + "string correlation ID cannot be longer than 128 characters") + .c_str()); + } + lrequest->SetCorrelationId(tc::InferenceRequest::SequenceId(correlation_id)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestPriority( + TRITONSERVER_InferenceRequest* inference_request, uint32_t* priority) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + *priority = lrequest->Priority(); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetPriority( + TRITONSERVER_InferenceRequest* inference_request, uint32_t priority) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + lrequest->SetPriority(priority); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestTimeoutMicroseconds( + TRITONSERVER_InferenceRequest* inference_request, uint64_t* timeout_us) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + *timeout_us = lrequest->TimeoutMicroseconds(); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetTimeoutMicroseconds( + TRITONSERVER_InferenceRequest* inference_request, uint64_t timeout_us) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + lrequest->SetTimeoutMicroseconds(timeout_us); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestAddInput( + TRITONSERVER_InferenceRequest* inference_request, const char* name, + const TRITONSERVER_DataType datatype, const int64_t* shape, + uint64_t dim_count) +{ + 
tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + RETURN_IF_STATUS_ERROR(lrequest->AddOriginalInput( + name, tc::TritonToDataType(datatype), shape, dim_count)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestAddRawInput( + TRITONSERVER_InferenceRequest* inference_request, const char* name) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + RETURN_IF_STATUS_ERROR(lrequest->AddRawInput(name)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestRemoveInput( + TRITONSERVER_InferenceRequest* inference_request, const char* name) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + RETURN_IF_STATUS_ERROR(lrequest->RemoveOriginalInput(name)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestRemoveAllInputs( + TRITONSERVER_InferenceRequest* inference_request) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + RETURN_IF_STATUS_ERROR(lrequest->RemoveAllOriginalInputs()); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestAppendInputData( + TRITONSERVER_InferenceRequest* inference_request, const char* name, + const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + + tc::InferenceRequest::Input* input; + RETURN_IF_STATUS_ERROR(lrequest->MutableOriginalInput(name, &input)); + RETURN_IF_STATUS_ERROR( + input->AppendData(base, byte_size, memory_type, memory_type_id)); + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestAppendInputDataWithHostPolicy( + TRITONSERVER_InferenceRequest* inference_request, const char* name, + const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id, const char* host_policy_name) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + + tc::InferenceRequest::Input* input; + RETURN_IF_STATUS_ERROR(lrequest->MutableOriginalInput(name, &input)); + RETURN_IF_STATUS_ERROR(input->AppendDataWithHostPolicy( + base, byte_size, memory_type, memory_type_id, host_policy_name)); + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestAppendInputDataWithBufferAttributes( + TRITONSERVER_InferenceRequest* inference_request, const char* name, + const void* base, TRITONSERVER_BufferAttributes* buffer_attributes) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + tc::BufferAttributes* lbuffer_attributes = + reinterpret_cast(buffer_attributes); + + tc::InferenceRequest::Input* input; + RETURN_IF_STATUS_ERROR(lrequest->MutableOriginalInput(name, &input)); + RETURN_IF_STATUS_ERROR( + input->AppendDataWithBufferAttributes(base, lbuffer_attributes)); + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestRemoveAllInputData( + TRITONSERVER_InferenceRequest* inference_request, const char* name) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + + tc::InferenceRequest::Input* input; + RETURN_IF_STATUS_ERROR(lrequest->MutableOriginalInput(name, &input)); + RETURN_IF_STATUS_ERROR(input->RemoveAllData()); + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* 
+TRITONSERVER_InferenceRequestAddRequestedOutput( + TRITONSERVER_InferenceRequest* inference_request, const char* name) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + RETURN_IF_STATUS_ERROR(lrequest->AddOriginalRequestedOutput(name)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestRemoveRequestedOutput( + TRITONSERVER_InferenceRequest* inference_request, const char* name) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + RETURN_IF_STATUS_ERROR(lrequest->RemoveOriginalRequestedOutput(name)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestRemoveAllRequestedOutputs( + TRITONSERVER_InferenceRequest* inference_request) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + RETURN_IF_STATUS_ERROR(lrequest->RemoveAllOriginalRequestedOutputs()); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetReleaseCallback( + TRITONSERVER_InferenceRequest* inference_request, + TRITONSERVER_InferenceRequestReleaseFn_t request_release_fn, + void* request_release_userp) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + RETURN_IF_STATUS_ERROR( + lrequest->SetReleaseCallback(request_release_fn, request_release_userp)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceRequestSetResponseCallback( + TRITONSERVER_InferenceRequest* inference_request, + TRITONSERVER_ResponseAllocator* response_allocator, + void* response_allocator_userp, + TRITONSERVER_InferenceResponseCompleteFn_t response_fn, + void* response_userp) +{ + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + tc::ResponseAllocator* lallocator = + reinterpret_cast(response_allocator); + RETURN_IF_STATUS_ERROR(lrequest->SetResponseCallback( + lallocator, response_allocator_userp, response_fn, response_userp)); + return nullptr; // Success +} + +// +// TRITONSERVER_InferenceResponse +// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceResponseDelete( + TRITONSERVER_InferenceResponse* inference_response) +{ + tc::InferenceResponse* lresponse = + reinterpret_cast(inference_response); + delete lresponse; + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceResponseError( + TRITONSERVER_InferenceResponse* inference_response) +{ + tc::InferenceResponse* lresponse = + reinterpret_cast(inference_response); + RETURN_IF_STATUS_ERROR(lresponse->ResponseStatus()); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceResponseModel( + TRITONSERVER_InferenceResponse* inference_response, const char** model_name, + int64_t* model_version) +{ + tc::InferenceResponse* lresponse = + reinterpret_cast(inference_response); + + *model_name = lresponse->ModelName().c_str(); + *model_version = lresponse->ActualModelVersion(); + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceResponseId( + TRITONSERVER_InferenceResponse* inference_response, const char** request_id) +{ + tc::InferenceResponse* lresponse = + reinterpret_cast(inference_response); + + *request_id = lresponse->Id().c_str(); + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceResponseParameterCount( + TRITONSERVER_InferenceResponse* inference_response, uint32_t* count) +{ + 
tc::InferenceResponse* lresponse = + reinterpret_cast(inference_response); + + const auto& parameters = lresponse->Parameters(); + *count = parameters.size(); + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceResponseParameter( + TRITONSERVER_InferenceResponse* inference_response, const uint32_t index, + const char** name, TRITONSERVER_ParameterType* type, const void** vvalue) +{ + tc::InferenceResponse* lresponse = + reinterpret_cast(inference_response); + + const auto& parameters = lresponse->Parameters(); + if (index >= parameters.size()) { + return TritonServerError::Create( + TRITONSERVER_ERROR_INVALID_ARG, + "out of bounds index " + std::to_string(index) + + std::string(": response has ") + std::to_string(parameters.size()) + + " parameters"); + } + + const tc::InferenceParameter& param = parameters[index]; + + *name = param.Name().c_str(); + *type = param.Type(); + *vvalue = param.ValuePointer(); + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceResponseOutputCount( + TRITONSERVER_InferenceResponse* inference_response, uint32_t* count) +{ + tc::InferenceResponse* lresponse = + reinterpret_cast(inference_response); + + const auto& outputs = lresponse->Outputs(); + *count = outputs.size(); + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceResponseOutput( + TRITONSERVER_InferenceResponse* inference_response, const uint32_t index, + const char** name, TRITONSERVER_DataType* datatype, const int64_t** shape, + uint64_t* dim_count, const void** base, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id, void** userp) +{ + tc::InferenceResponse* lresponse = + reinterpret_cast(inference_response); + + const auto& outputs = lresponse->Outputs(); + if (index >= outputs.size()) { + return TritonServerError::Create( + TRITONSERVER_ERROR_INVALID_ARG, + "out of bounds index " + std::to_string(index) + + std::string(": response has ") + std::to_string(outputs.size()) + + " outputs"); + } + + const tc::InferenceResponse::Output& output = outputs[index]; + + *name = output.Name().c_str(); + *datatype = tc::DataTypeToTriton(output.DType()); + + const std::vector& oshape = output.Shape(); + *shape = &oshape[0]; + *dim_count = oshape.size(); + + RETURN_IF_STATUS_ERROR( + output.DataBuffer(base, byte_size, memory_type, memory_type_id, userp)); + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_InferenceResponseOutputClassificationLabel( + TRITONSERVER_InferenceResponse* inference_response, const uint32_t index, + const size_t class_index, const char** label) +{ + tc::InferenceResponse* lresponse = + reinterpret_cast(inference_response); + + const auto& outputs = lresponse->Outputs(); + if (index >= outputs.size()) { + return TritonServerError::Create( + TRITONSERVER_ERROR_INVALID_ARG, + "out of bounds index " + std::to_string(index) + + std::string(": response has ") + std::to_string(outputs.size()) + + " outputs"); + } + + const tc::InferenceResponse::Output& output = outputs[index]; + RETURN_IF_STATUS_ERROR( + lresponse->ClassificationLabel(output, class_index, label)); + + return nullptr; // Success +} + +// +// TRITONSERVER_BufferAttributes +// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_BufferAttributesNew( + TRITONSERVER_BufferAttributes** buffer_attributes) +{ + tc::BufferAttributes* lbuffer_attributes = new tc::BufferAttributes(); + *buffer_attributes = + 
reinterpret_cast(lbuffer_attributes); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_BufferAttributesDelete( + TRITONSERVER_BufferAttributes* buffer_attributes) +{ + tc::BufferAttributes* lbuffer_attributes = + reinterpret_cast(buffer_attributes); + delete lbuffer_attributes; + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_BufferAttributesSetMemoryTypeId( + TRITONSERVER_BufferAttributes* buffer_attributes, int64_t memory_type_id) +{ + tc::BufferAttributes* lbuffer_attributes = + reinterpret_cast(buffer_attributes); + lbuffer_attributes->SetMemoryTypeId(memory_type_id); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_BufferAttributesSetMemoryType( + TRITONSERVER_BufferAttributes* buffer_attributes, + TRITONSERVER_MemoryType memory_type) +{ + tc::BufferAttributes* lbuffer_attributes = + reinterpret_cast(buffer_attributes); + lbuffer_attributes->SetMemoryType(memory_type); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_BufferAttributesSetCudaIpcHandle( + TRITONSERVER_BufferAttributes* buffer_attributes, void* cuda_ipc_handle) +{ + tc::BufferAttributes* lbuffer_attributes = + reinterpret_cast(buffer_attributes); + lbuffer_attributes->SetCudaIpcHandle(cuda_ipc_handle); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_BufferAttributesSetByteSize( + TRITONSERVER_BufferAttributes* buffer_attributes, size_t byte_size) +{ + tc::BufferAttributes* lbuffer_attributes = + reinterpret_cast(buffer_attributes); + lbuffer_attributes->SetByteSize(byte_size); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_BufferAttributesMemoryTypeId( + TRITONSERVER_BufferAttributes* buffer_attributes, int64_t* memory_type_id) +{ + tc::BufferAttributes* lbuffer_attributes = + reinterpret_cast(buffer_attributes); + *memory_type_id = lbuffer_attributes->MemoryTypeId(); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_BufferAttributesMemoryType( + TRITONSERVER_BufferAttributes* buffer_attributes, + TRITONSERVER_MemoryType* memory_type) +{ + tc::BufferAttributes* lbuffer_attributes = + reinterpret_cast(buffer_attributes); + *memory_type = lbuffer_attributes->MemoryType(); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_BufferAttributesCudaIpcHandle( + TRITONSERVER_BufferAttributes* buffer_attributes, void** cuda_ipc_handle) +{ + tc::BufferAttributes* lbuffer_attributes = + reinterpret_cast(buffer_attributes); + *cuda_ipc_handle = lbuffer_attributes->CudaIpcHandle(); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_BufferAttributesByteSize( + TRITONSERVER_BufferAttributes* buffer_attributes, size_t* byte_size) +{ + tc::BufferAttributes* lbuffer_attributes = + reinterpret_cast(buffer_attributes); + *byte_size = lbuffer_attributes->ByteSize(); + + return nullptr; // success +} + +// +// TRITONSERVER_Server +// +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerNew( + TRITONSERVER_Server** server, TRITONSERVER_ServerOptions* options) +{ + tc::InferenceServer* lserver = new tc::InferenceServer(); + TritonServerOptions* loptions = + reinterpret_cast(options); + + NVTX_INITIALIZE; + +#ifdef TRITON_ENABLE_METRICS + // NOTE: Metrics must be enabled before backends are setup + if (loptions->Metrics()) { + tc::Metrics::EnableMetrics(); + 
tc::Metrics::SetMetricsInterval(loptions->MetricsInterval()); + } +#endif // TRITON_ENABLE_METRICS + + lserver->SetId(loptions->ServerId()); + lserver->SetModelRepositoryPaths(loptions->ModelRepositoryPaths()); + lserver->SetModelControlMode(loptions->ModelControlMode()); + lserver->SetStartupModels(loptions->StartupModels()); + bool strict_model_config = loptions->StrictModelConfig(); + lserver->SetStrictModelConfigEnabled(strict_model_config); + lserver->SetRateLimiterMode(loptions->RateLimiterMode()); + lserver->SetRateLimiterResources(loptions->RateLimiterResources()); + lserver->SetPinnedMemoryPoolByteSize(loptions->PinnedMemoryPoolByteSize()); + lserver->SetResponseCacheByteSize(loptions->ResponseCacheByteSize()); + lserver->SetCudaMemoryPoolByteSize(loptions->CudaMemoryPoolByteSize()); + double min_compute_capability = loptions->MinSupportedComputeCapability(); + lserver->SetMinSupportedComputeCapability(min_compute_capability); + lserver->SetStrictReadinessEnabled(loptions->StrictReadiness()); + lserver->SetExitTimeoutSeconds(loptions->ExitTimeout()); + lserver->SetHostPolicyCmdlineConfig(loptions->HostPolicyCmdlineConfigMap()); + lserver->SetRepoAgentDir(loptions->RepoAgentDir()); + lserver->SetBufferManagerThreadCount(loptions->BufferManagerThreadCount()); + lserver->SetModelLoadThreadCount(loptions->ModelLoadThreadCount()); + + // SetBackendCmdlineConfig must be called after all AddBackendConfig calls + // have completed. + // Note that the auto complete config condition is reverted + // due to setting name being different + loptions->AddBackendConfig( + std::string(), "auto-complete-config", + strict_model_config ? "false" : "true"); + loptions->AddBackendConfig( + std::string(), "min-compute-capability", + std::to_string(min_compute_capability)); + loptions->AddBackendConfig( + std::string(), "backend-directory", loptions->BackendDir()); + lserver->SetBackendCmdlineConfig(loptions->BackendCmdlineConfigMap()); + + // Initialize server + tc::Status status = lserver->Init(); + +#ifdef TRITON_ENABLE_METRICS + if (loptions->Metrics() && lserver->ResponseCacheEnabled()) { + // NOTE: Cache metrics must be enabled after cache initialized in + // server->Init() + tc::Metrics::EnableCacheMetrics(lserver->GetResponseCache()); + } +#ifdef TRITON_ENABLE_METRICS_GPU + if (loptions->Metrics() && loptions->GpuMetrics()) { + tc::Metrics::EnableGPUMetrics(); + } +#endif // TRITON_ENABLE_METRICS_GPU + +#ifdef TRITON_ENABLE_METRICS_CPU + if (loptions->Metrics() && loptions->CpuMetrics()) { + tc::Metrics::EnableCpuMetrics(); + } +#endif // TRITON_ENABLE_METRICS_CPU + + const bool poll_metrics = + (lserver->ResponseCacheEnabled() || loptions->GpuMetrics() || + loptions->CpuMetrics()); + if (loptions->Metrics() && poll_metrics) { + // Start thread to poll enabled metrics periodically + tc::Metrics::StartPollingThreadSingleton(lserver->GetResponseCache()); + } +#endif // TRITON_ENABLE_METRICS + + + // Setup tritonserver options table + std::vector options_headers; + options_headers.emplace_back("Option"); + options_headers.emplace_back("Value"); + + triton::common::TablePrinter options_table(options_headers); + options_table.InsertRow(std::vector{"server_id", lserver->Id()}); + options_table.InsertRow( + std::vector{"server_version", lserver->Version()}); + + auto extensions = lserver->Extensions(); + std::string exts; + for (const auto& ext : extensions) { + exts.append(ext); + exts.append(" "); + } + + // Remove the trailing space + if (exts.size() > 0) + exts.pop_back(); + + 
options_table.InsertRow(std::vector{"server_extensions", exts}); + + size_t i = 0; + for (const auto& model_repository_path : lserver->ModelRepositoryPaths()) { + options_table.InsertRow(std::vector{ + "model_repository_path[" + std::to_string(i) + "]", + model_repository_path}); + ++i; + } + + std::string model_control_mode; + auto control_mode = lserver->GetModelControlMode(); + switch (control_mode) { + case tc::ModelControlMode::MODE_NONE: { + model_control_mode = "MODE_NONE"; + break; + } + case tc::ModelControlMode::MODE_POLL: { + model_control_mode = "MODE_POLL"; + break; + } + case tc::ModelControlMode::MODE_EXPLICIT: { + model_control_mode = "MODE_EXPLICIT"; + break; + } + default: { + model_control_mode = ""; + } + } + options_table.InsertRow( + std::vector{"model_control_mode", model_control_mode}); + + i = 0; + for (const auto& startup_model : lserver->StartupModels()) { + options_table.InsertRow(std::vector{ + "startup_models_" + std::to_string(i), startup_model}); + ++i; + } + options_table.InsertRow(std::vector{ + "strict_model_config", + std::to_string(lserver->StrictModelConfigEnabled())}); + std::string rate_limit = RateLimitModeToString(lserver->RateLimiterMode()); + options_table.InsertRow(std::vector{"rate_limit", rate_limit}); + i = 0; + for (const auto& device_resources : lserver->RateLimiterResources()) { + for (const auto& resource : device_resources.second) { + options_table.InsertRow(std::vector{ + "rate_limit_resource[" + std::to_string(i) + "]", + ResourceString( + resource.first, resource.second, device_resources.first)}); + ++i; + } + } + options_table.InsertRow(std::vector{ + "pinned_memory_pool_byte_size", + std::to_string(lserver->PinnedMemoryPoolByteSize())}); + for (const auto& cuda_memory_pool : lserver->CudaMemoryPoolByteSize()) { + options_table.InsertRow(std::vector{ + "cuda_memory_pool_byte_size{" + std::to_string(cuda_memory_pool.first) + + "}", + std::to_string(cuda_memory_pool.second)}); + } + options_table.InsertRow(std::vector{ + "response_cache_byte_size", + std::to_string(lserver->ResponseCacheByteSize())}); + + std::stringstream compute_capability_ss; + compute_capability_ss.setf(std::ios::fixed); + compute_capability_ss.precision(1); + compute_capability_ss << lserver->MinSupportedComputeCapability(); + options_table.InsertRow(std::vector{ + "min_supported_compute_capability", compute_capability_ss.str()}); + options_table.InsertRow(std::vector{ + "strict_readiness", std::to_string(lserver->StrictReadinessEnabled())}); + options_table.InsertRow(std::vector{ + "exit_timeout", std::to_string(lserver->ExitTimeoutSeconds())}); + + std::string options_table_string = options_table.PrintTable(); + LOG_INFO << options_table_string; + + if (!status.IsOk()) { + if (loptions->ExitOnError()) { + lserver->Stop(true /* force */); + delete lserver; + RETURN_IF_STATUS_ERROR(status); + } + + LOG_ERROR << status.AsString(); + } + + *server = reinterpret_cast(lserver); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerDelete(TRITONSERVER_Server* server) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + if (lserver != nullptr) { + RETURN_IF_STATUS_ERROR(lserver->Stop()); + } + delete lserver; + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerStop(TRITONSERVER_Server* server) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + if (lserver != nullptr) { + RETURN_IF_STATUS_ERROR(lserver->Stop()); + } + return nullptr; // Success +} + 
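// Illustrative usage sketch (editorial addition, not part of the vendored
// source): a minimal in-process server lifecycle built on the C API
// implemented above. The model repository path is a placeholder and error
// handling is reduced to a single helper.
#include "triton/core/tritonserver.h"
#include <stdio.h>
#include <stdlib.h>

static void
Check(TRITONSERVER_Error* err)
{
  if (err != nullptr) {
    fprintf(
        stderr, "error: %s - %s\n", TRITONSERVER_ErrorCodeString(err),
        TRITONSERVER_ErrorMessage(err));
    TRITONSERVER_ErrorDelete(err);
    exit(1);
  }
}

int
main()
{
  TRITONSERVER_ServerOptions* options = nullptr;
  Check(TRITONSERVER_ServerOptionsNew(&options));
  Check(TRITONSERVER_ServerOptionsSetModelRepositoryPath(
      options, "/path/to/model_repository"));  // placeholder path
  Check(TRITONSERVER_ServerOptionsSetLogVerbose(options, 0));

  TRITONSERVER_Server* server = nullptr;
  Check(TRITONSERVER_ServerNew(&server, options));  // logs the options table
  Check(TRITONSERVER_ServerOptionsDelete(options));

  bool live = false;
  bool ready = false;
  Check(TRITONSERVER_ServerIsLive(server, &live));
  Check(TRITONSERVER_ServerIsReady(server, &ready));
  printf("live=%d ready=%d\n", live, ready);

  Check(TRITONSERVER_ServerStop(server));  // ServerDelete also stops the server
  Check(TRITONSERVER_ServerDelete(server));
  return 0;
}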
+TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerRegisterModelRepository( + TRITONSERVER_Server* server, const char* repository_path, + const TRITONSERVER_Parameter** name_mapping, const uint32_t mapping_count) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + if ((name_mapping == nullptr) && (mapping_count != 0)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "model mappings are not provided while mapping count is non-zero"); + } + + std::unordered_map model_mapping; + for (size_t i = 0; i < mapping_count; ++i) { + auto mapping = + reinterpret_cast(name_mapping[i]); + auto subdir = mapping->Name(); + + if (mapping->Type() != TRITONSERVER_PARAMETER_STRING) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "Mapped model name must be a string, found " + "another type for " + + subdir) + .c_str()); + } + + auto model_name = + std::string(reinterpret_cast(mapping->ValuePointer())); + + if (!(model_mapping.emplace(model_name, subdir).second)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("failed to register '") + repository_path + + "', there is a conflicting mapping for '" + std::string(model_name) + + "'") + .c_str()); + } + } + RETURN_IF_STATUS_ERROR( + lserver->RegisterModelRepository(repository_path, model_mapping)); + return nullptr; // Success +} + +TRITONSERVER_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerUnregisterModelRepository( + TRITONSERVER_Server* server, const char* repository_path) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + RETURN_IF_STATUS_ERROR(lserver->UnregisterModelRepository(repository_path)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerPollModelRepository(TRITONSERVER_Server* server) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + RETURN_IF_STATUS_ERROR(lserver->PollModelRepository()); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerIsLive(TRITONSERVER_Server* server, bool* live) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + + RETURN_IF_STATUS_ERROR(lserver->IsLive(live)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerIsReady(TRITONSERVER_Server* server, bool* ready) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + + RETURN_IF_STATUS_ERROR(lserver->IsReady(ready)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerModelIsReady( + TRITONSERVER_Server* server, const char* model_name, + const int64_t model_version, bool* ready) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + + RETURN_IF_STATUS_ERROR( + lserver->ModelIsReady(model_name, model_version, ready)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerModelBatchProperties( + TRITONSERVER_Server* server, const char* model_name, + const int64_t model_version, uint32_t* flags, void** voidp) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + + if (voidp != nullptr) { + *voidp = nullptr; + } + + std::shared_ptr model; + RETURN_IF_STATUS_ERROR(lserver->GetModel(model_name, model_version, &model)); + + if (model->Config().max_batch_size() > 0) { + *flags = TRITONSERVER_BATCH_FIRST_DIM; + } else { + *flags = TRITONSERVER_BATCH_UNKNOWN; + } + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerModelTransactionProperties( + 
TRITONSERVER_Server* server, const char* model_name, + const int64_t model_version, uint32_t* txn_flags, void** voidp) +{ + tc::InferenceServer* lserver = reinterpret_cast<tc::InferenceServer*>(server); + + if (voidp != nullptr) { + *voidp = nullptr; + } + + *txn_flags = 0; + + std::shared_ptr<tc::Model> model; + RETURN_IF_STATUS_ERROR(lserver->GetModel(model_name, model_version, &model)); + + if (model->Config().model_transaction_policy().decoupled()) { + *txn_flags = TRITONSERVER_TXN_DECOUPLED; + } else { + *txn_flags = TRITONSERVER_TXN_ONE_TO_ONE; + } + + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerMetadata( + TRITONSERVER_Server* server, TRITONSERVER_Message** server_metadata) +{ + tc::InferenceServer* lserver = reinterpret_cast<tc::InferenceServer*>(server); + + triton::common::TritonJson::Value metadata( + triton::common::TritonJson::ValueType::OBJECT); + + // Just store string reference in JSON object since it will be + // serialized to another buffer before lserver->Id() or + // lserver->Version() lifetime ends. + RETURN_IF_STATUS_ERROR(metadata.AddStringRef("name", lserver->Id().c_str())); + RETURN_IF_STATUS_ERROR( + metadata.AddStringRef("version", lserver->Version().c_str())); + + triton::common::TritonJson::Value extensions( + metadata, triton::common::TritonJson::ValueType::ARRAY); + const std::vector<const char*>& exts = lserver->Extensions(); + for (const auto ext : exts) { + RETURN_IF_STATUS_ERROR(extensions.AppendStringRef(ext)); + } + + RETURN_IF_STATUS_ERROR(metadata.Add("extensions", std::move(extensions))); + + *server_metadata = reinterpret_cast<TRITONSERVER_Message*>( + new tc::TritonServerMessage(metadata)); + return nullptr; // Success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerModelMetadata( + TRITONSERVER_Server* server, const char* model_name, + const int64_t model_version, TRITONSERVER_Message** model_metadata) +{ + tc::InferenceServer* lserver = reinterpret_cast<tc::InferenceServer*>(server); + + std::shared_ptr<tc::Model> model; + RETURN_IF_STATUS_ERROR(lserver->GetModel(model_name, model_version, &model)); + + std::vector<int64_t> ready_versions; + RETURN_IF_STATUS_ERROR( + lserver->ModelReadyVersions(model_name, &ready_versions)); + + triton::common::TritonJson::Value metadata( + triton::common::TritonJson::ValueType::OBJECT); + + // Can use string ref in this function even though model can be + // unloaded and config becomes invalid, because TritonServerMessage + // serializes the json when it is constructed below. 
+ RETURN_IF_STATUS_ERROR(metadata.AddStringRef("name", model_name)); + + triton::common::TritonJson::Value versions( + metadata, triton::common::TritonJson::ValueType::ARRAY); + if (model_version != -1) { + RETURN_IF_STATUS_ERROR( + versions.AppendString(std::move(std::to_string(model_version)))); + } else { + for (const auto v : ready_versions) { + RETURN_IF_STATUS_ERROR( + versions.AppendString(std::move(std::to_string(v)))); + } + } + + RETURN_IF_STATUS_ERROR(metadata.Add("versions", std::move(versions))); + + const auto& model_config = model->Config(); + if (!model_config.platform().empty()) { + RETURN_IF_STATUS_ERROR( + metadata.AddStringRef("platform", model_config.platform().c_str())); + } else { + RETURN_IF_STATUS_ERROR( + metadata.AddStringRef("platform", model_config.backend().c_str())); + } + + triton::common::TritonJson::Value inputs( + metadata, triton::common::TritonJson::ValueType::ARRAY); + for (const auto& io : model_config.input()) { + triton::common::TritonJson::Value io_metadata( + metadata, triton::common::TritonJson::ValueType::OBJECT); + RETURN_IF_STATUS_ERROR(io_metadata.AddStringRef("name", io.name().c_str())); + RETURN_IF_STATUS_ERROR(io_metadata.AddStringRef( + "datatype", triton::common::DataTypeToProtocolString(io.data_type()))); + + // Input shape. If the model supports batching then must include + // '-1' for the batch dimension. + triton::common::TritonJson::Value io_metadata_shape( + metadata, triton::common::TritonJson::ValueType::ARRAY); + if (model_config.max_batch_size() >= 1) { + RETURN_IF_STATUS_ERROR(io_metadata_shape.AppendInt(-1)); + } + for (const auto d : io.dims()) { + RETURN_IF_STATUS_ERROR(io_metadata_shape.AppendInt(d)); + } + RETURN_IF_STATUS_ERROR( + io_metadata.Add("shape", std::move(io_metadata_shape))); + + RETURN_IF_STATUS_ERROR(inputs.Append(std::move(io_metadata))); + } + RETURN_IF_STATUS_ERROR(metadata.Add("inputs", std::move(inputs))); + + triton::common::TritonJson::Value outputs( + metadata, triton::common::TritonJson::ValueType::ARRAY); + for (const auto& io : model_config.output()) { + triton::common::TritonJson::Value io_metadata( + metadata, triton::common::TritonJson::ValueType::OBJECT); + RETURN_IF_STATUS_ERROR(io_metadata.AddStringRef("name", io.name().c_str())); + RETURN_IF_STATUS_ERROR(io_metadata.AddStringRef( + "datatype", triton::common::DataTypeToProtocolString(io.data_type()))); + + // Output shape. If the model supports batching then must include + // '-1' for the batch dimension. 
+ triton::common::TritonJson::Value io_metadata_shape( + metadata, triton::common::TritonJson::ValueType::ARRAY); + if (model_config.max_batch_size() >= 1) { + RETURN_IF_STATUS_ERROR(io_metadata_shape.AppendInt(-1)); + } + for (const auto d : io.dims()) { + RETURN_IF_STATUS_ERROR(io_metadata_shape.AppendInt(d)); + } + RETURN_IF_STATUS_ERROR( + io_metadata.Add("shape", std::move(io_metadata_shape))); + + RETURN_IF_STATUS_ERROR(outputs.Append(std::move(io_metadata))); + } + RETURN_IF_STATUS_ERROR(metadata.Add("outputs", std::move(outputs))); + + *model_metadata = reinterpret_cast<TRITONSERVER_Message*>( + new tc::TritonServerMessage(metadata)); + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerModelStatistics( + TRITONSERVER_Server* server, const char* model_name, + const int64_t model_version, TRITONSERVER_Message** model_stats) +{ +#ifndef TRITON_ENABLE_STATS + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "statistics not supported"); +#else + + tc::InferenceServer* lserver = reinterpret_cast<tc::InferenceServer*>(server); + + auto model_name_string = std::string(model_name); + std::map<std::string, std::vector<int64_t>> ready_model_versions; + if (model_name_string.empty()) { + RETURN_IF_STATUS_ERROR(lserver->ModelReadyVersions(&ready_model_versions)); + } else { + std::vector<int64_t> ready_versions; + RETURN_IF_STATUS_ERROR( + lserver->ModelReadyVersions(model_name_string, &ready_versions)); + if (ready_versions.empty()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "requested model '" + model_name_string + "' is not available") + .c_str()); + } + + if (model_version == -1) { + ready_model_versions.emplace( + model_name_string, std::move(ready_versions)); + } else { + bool found = false; + for (const auto v : ready_versions) { + if (v == model_version) { + found = true; + break; + } + } + if (found) { + ready_model_versions.emplace( + model_name_string, std::vector<int64_t>{model_version}); + } else { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "requested model version is not available for model '" + + model_name_string + "'") + .c_str()); + } + } + } + + // Can use string ref in this function because TritonServerMessage + // serializes the json when it is constructed below. + triton::common::TritonJson::Value metadata( + triton::common::TritonJson::ValueType::OBJECT); + + triton::common::TritonJson::Value model_stats_json( + metadata, triton::common::TritonJson::ValueType::ARRAY); + for (const auto& mv_pair : ready_model_versions) { + for (const auto& version : mv_pair.second) { + std::shared_ptr<tc::Model> model; + RETURN_IF_STATUS_ERROR(lserver->GetModel(mv_pair.first, version, &model)); + const auto& infer_stats = model->StatsAggregator().ImmutableInferStats(); + const auto& infer_batch_stats = + model->StatsAggregator().ImmutableInferBatchStats(); + + triton::common::TritonJson::Value inference_stats( + metadata, triton::common::TritonJson::ValueType::OBJECT); + // Compute figures only calculated when not going through cache, so + // subtract cache_hit count from success count. Cache hit count will + // simply be 0 when cache is disabled. 
+ uint64_t compute_count = + infer_stats.success_count_ - infer_stats.cache_hit_count_; + SetDurationStat( + metadata, inference_stats, "success", infer_stats.success_count_, + infer_stats.request_duration_ns_); + SetDurationStat( + metadata, inference_stats, "fail", infer_stats.failure_count_, + infer_stats.failure_duration_ns_); + SetDurationStat( + metadata, inference_stats, "queue", infer_stats.success_count_, + infer_stats.queue_duration_ns_); + SetDurationStat( + metadata, inference_stats, "compute_input", compute_count, + infer_stats.compute_input_duration_ns_); + SetDurationStat( + metadata, inference_stats, "compute_infer", compute_count, + infer_stats.compute_infer_duration_ns_); + SetDurationStat( + metadata, inference_stats, "compute_output", compute_count, + infer_stats.compute_output_duration_ns_); + SetDurationStat( + metadata, inference_stats, "cache_hit", infer_stats.cache_hit_count_, + infer_stats.cache_hit_lookup_duration_ns_); + // NOTE: cache_miss_count_ should equal compute_count if non-zero + SetDurationStat( + metadata, inference_stats, "cache_miss", + infer_stats.cache_miss_count_, + infer_stats.cache_miss_lookup_duration_ns_ + + infer_stats.cache_miss_insertion_duration_ns_); + + triton::common::TritonJson::Value batch_stats( + metadata, triton::common::TritonJson::ValueType::ARRAY); + for (const auto& batch : infer_batch_stats) { + triton::common::TritonJson::Value batch_stat( + metadata, triton::common::TritonJson::ValueType::OBJECT); + RETURN_IF_STATUS_ERROR(batch_stat.AddUInt("batch_size", batch.first)); + SetDurationStat( + metadata, batch_stat, "compute_input", batch.second.count_, + batch.second.compute_input_duration_ns_); + SetDurationStat( + metadata, batch_stat, "compute_infer", batch.second.count_, + batch.second.compute_infer_duration_ns_); + SetDurationStat( + metadata, batch_stat, "compute_output", batch.second.count_, + batch.second.compute_output_duration_ns_); + RETURN_IF_STATUS_ERROR(batch_stats.Append(std::move(batch_stat))); + } + + triton::common::TritonJson::Value model_stat( + metadata, triton::common::TritonJson::ValueType::OBJECT); + RETURN_IF_STATUS_ERROR( + model_stat.AddStringRef("name", mv_pair.first.c_str())); + RETURN_IF_STATUS_ERROR( + model_stat.AddString("version", std::move(std::to_string(version)))); + + RETURN_IF_STATUS_ERROR(model_stat.AddUInt( + "last_inference", model->StatsAggregator().LastInferenceMs())); + RETURN_IF_STATUS_ERROR(model_stat.AddUInt( + "inference_count", model->StatsAggregator().InferenceCount())); + RETURN_IF_STATUS_ERROR(model_stat.AddUInt( + "execution_count", model->StatsAggregator().ExecutionCount())); + + RETURN_IF_STATUS_ERROR( + model_stat.Add("inference_stats", std::move(inference_stats))); + RETURN_IF_STATUS_ERROR( + model_stat.Add("batch_stats", std::move(batch_stats))); + + RETURN_IF_STATUS_ERROR(model_stats_json.Append(std::move(model_stat))); + } + } + + RETURN_IF_STATUS_ERROR( + metadata.Add("model_stats", std::move(model_stats_json))); + *model_stats = reinterpret_cast( + new tc::TritonServerMessage(metadata)); + + return nullptr; // success + +#endif // TRITON_ENABLE_STATS +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerModelConfig( + TRITONSERVER_Server* server, const char* model_name, + const int64_t model_version, const uint32_t config_version, + TRITONSERVER_Message** model_config) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + + std::shared_ptr model; + RETURN_IF_STATUS_ERROR(lserver->GetModel(model_name, model_version, &model)); + + std::string 
model_config_json; + RETURN_IF_STATUS_ERROR(tc::ModelConfigToJson( + model->Config(), config_version, &model_config_json)); + + *model_config = reinterpret_cast( + new tc::TritonServerMessage(std::move(model_config_json))); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerModelIndex( + TRITONSERVER_Server* server, uint32_t flags, + TRITONSERVER_Message** repository_index) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + + const bool ready_only = ((flags & TRITONSERVER_INDEX_FLAG_READY) != 0); + + std::vector index; + RETURN_IF_STATUS_ERROR(lserver->RepositoryIndex(ready_only, &index)); + + // Can use string ref in this function because TritonServerMessage + // serializes the json when it is constructed below. + triton::common::TritonJson::Value repository_index_json( + triton::common::TritonJson::ValueType::ARRAY); + + for (const auto& in : index) { + triton::common::TritonJson::Value model_index( + repository_index_json, triton::common::TritonJson::ValueType::OBJECT); + RETURN_IF_STATUS_ERROR(model_index.AddStringRef("name", in.name_.c_str())); + if (!in.name_only_) { + if (in.version_ >= 0) { + RETURN_IF_STATUS_ERROR(model_index.AddString( + "version", std::move(std::to_string(in.version_)))); + } + RETURN_IF_STATUS_ERROR(model_index.AddStringRef( + "state", tc::ModelReadyStateString(in.state_).c_str())); + if (!in.reason_.empty()) { + RETURN_IF_STATUS_ERROR( + model_index.AddStringRef("reason", in.reason_.c_str())); + } + } + + RETURN_IF_STATUS_ERROR( + repository_index_json.Append(std::move(model_index))); + } + + *repository_index = reinterpret_cast( + new tc::TritonServerMessage(repository_index_json)); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerLoadModel( + TRITONSERVER_Server* server, const char* model_name) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + + RETURN_IF_STATUS_ERROR(lserver->LoadModel({{std::string(model_name), {}}})); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerLoadModelWithParameters( + TRITONSERVER_Server* server, const char* model_name, + const TRITONSERVER_Parameter** parameters, const uint64_t parameter_count) +{ + if ((parameters == nullptr) && (parameter_count != 0)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "load parameters are not provided while parameter count is non-zero"); + } + + tc::InferenceServer* lserver = reinterpret_cast(server); + + std::unordered_map> + models; + std::vector mp; + for (size_t i = 0; i < parameter_count; ++i) { + mp.emplace_back( + reinterpret_cast(parameters[i])); + } + models[model_name] = std::move(mp); + RETURN_IF_STATUS_ERROR(lserver->LoadModel(models)); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerUnloadModel( + TRITONSERVER_Server* server, const char* model_name) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + + RETURN_IF_STATUS_ERROR(lserver->UnloadModel( + std::string(model_name), false /* unload_dependents */)); + + return nullptr; // success +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerUnloadModelAndDependents( + TRITONSERVER_Server* server, const char* model_name) +{ + { + tc::InferenceServer* lserver = + reinterpret_cast(server); + + RETURN_IF_STATUS_ERROR(lserver->UnloadModel( + std::string(model_name), true /* unload_dependents */)); + + return nullptr; // success + } +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* 
+TRITONSERVER_ServerMetrics( + TRITONSERVER_Server* server, TRITONSERVER_Metrics** metrics) +{ +#ifdef TRITON_ENABLE_METRICS + TritonServerMetrics* lmetrics = new TritonServerMetrics(); + *metrics = reinterpret_cast(lmetrics); + return nullptr; // Success +#else + *metrics = nullptr; + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +TRITONAPI_DECLSPEC TRITONSERVER_Error* +TRITONSERVER_ServerInferAsync( + TRITONSERVER_Server* server, + TRITONSERVER_InferenceRequest* inference_request, + TRITONSERVER_InferenceTrace* trace) +{ + tc::InferenceServer* lserver = reinterpret_cast(server); + tc::InferenceRequest* lrequest = + reinterpret_cast(inference_request); + + RETURN_IF_STATUS_ERROR(lrequest->PrepareForInference()); + + // Set the trace object in the request so that activity associated + // with the request can be recorded as the request flows through + // Triton. + if (trace != nullptr) { +#ifdef TRITON_ENABLE_TRACING + tc::InferenceTrace* ltrace = reinterpret_cast(trace); + ltrace->SetModelName(lrequest->ModelName()); + ltrace->SetModelVersion(lrequest->ActualModelVersion()); + + lrequest->SetTrace(std::make_shared(ltrace)); +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "inference tracing not supported"); +#endif // TRITON_ENABLE_TRACING + } + + // We wrap the request in a unique pointer to ensure that it flows + // through inferencing with clear ownership. + std::unique_ptr ureq(lrequest); + + // Run inference... + tc::Status status = lserver->InferAsync(ureq); + + // If there is an error then must explicitly release any trace + // object associated with the inference request above. +#ifdef TRITON_ENABLE_TRACING + if (!status.IsOk()) { + ureq->ReleaseTrace(); + } +#endif // TRITON_ENABLE_TRACING + + // If there is an error then ureq will still have 'lrequest' and we + // must release it from unique_ptr since the caller should retain + // ownership when there is error. If there is not an error then ureq + // == nullptr and so this release is a nop. 
+ ureq.release(); + + RETURN_IF_STATUS_ERROR(status); + return nullptr; // Success +} + +// +// TRITONSERVER_MetricFamily +// +TRITONSERVER_Error* +TRITONSERVER_MetricFamilyNew( + TRITONSERVER_MetricFamily** family, TRITONSERVER_MetricKind kind, + const char* name, const char* description) +{ +#ifdef TRITON_ENABLE_METRICS + try { + *family = reinterpret_cast( + new tc::MetricFamily(kind, name, description)); + } + catch (std::invalid_argument const& ex) { + // Catch invalid kinds passed to constructor + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, ex.what()); + } + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +TRITONSERVER_Error* +TRITONSERVER_MetricFamilyDelete(TRITONSERVER_MetricFamily* family) +{ +#ifdef TRITON_ENABLE_METRICS + auto lfamily = reinterpret_cast(family); + if (lfamily->NumMetrics() > 0) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "Must call MetricDelete on all dependent metrics before calling " + "MetricFamilyDelete."); + } + + delete lfamily; + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +// +// TRITONSERVER_Metric +// +TRITONSERVER_Error* +TRITONSERVER_MetricNew( + TRITONSERVER_Metric** metric, TRITONSERVER_MetricFamily* family, + const TRITONSERVER_Parameter** labels, const uint64_t label_count) +{ +#ifdef TRITON_ENABLE_METRICS + std::vector labels_vec; + for (size_t i = 0; i < label_count; i++) { + labels_vec.emplace_back( + reinterpret_cast(labels[i])); + } + + try { + *metric = reinterpret_cast( + new tc::Metric(family, labels_vec)); + } + catch (std::invalid_argument const& ex) { + // Catch invalid kinds passed to constructor + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, ex.what()); + } + + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +TRITONSERVER_Error* +TRITONSERVER_MetricDelete(TRITONSERVER_Metric* metric) +{ +#ifdef TRITON_ENABLE_METRICS + auto lmetric = reinterpret_cast(metric); + if (lmetric->Family() == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "MetricFamily reference was invalidated before Metric was deleted. 
" + "Must call MetricDelete on all dependent metrics before calling " + "MetricFamilyDelete."); + } + + delete lmetric; + return nullptr; // success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +TRITONSERVER_Error* +TRITONSERVER_MetricValue(TRITONSERVER_Metric* metric, double* value) +{ +#ifdef TRITON_ENABLE_METRICS + return reinterpret_cast(metric)->Value(value); +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +TRITONSERVER_Error* +TRITONSERVER_MetricIncrement(TRITONSERVER_Metric* metric, double value) +{ +#ifdef TRITON_ENABLE_METRICS + return reinterpret_cast(metric)->Increment(value); +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +TRITONSERVER_Error* +TRITONSERVER_MetricSet(TRITONSERVER_Metric* metric, double value) +{ +#ifdef TRITON_ENABLE_METRICS + return reinterpret_cast(metric)->Set(value); +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +TRITONSERVER_Error* +TRITONSERVER_GetMetricKind( + TRITONSERVER_Metric* metric, TRITONSERVER_MetricKind* kind) +{ +#ifdef TRITON_ENABLE_METRICS + *kind = reinterpret_cast(metric)->Kind(); + return nullptr; // Success +#else + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported"); +#endif // TRITON_ENABLE_METRICS +} + +} // extern C diff --git a/3rdparty/core-r22.12/src/tritonserver_apis.h b/3rdparty/core-r22.12/src/tritonserver_apis.h new file mode 100644 index 0000000000000000000000000000000000000000..5f2af2dc76a74b4d6af9453eb98570868f0677b4 --- /dev/null +++ b/3rdparty/core-r22.12/src/tritonserver_apis.h @@ -0,0 +1,38 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+#pragma once + +#define _COMPILING_TRITONSERVER 1 +#define _COMPILING_TRITONBACKEND 1 +#define _COMPILING_TRITONREPOAGENT 1 + +#include "triton/core/tritonbackend.h" +#include "triton/core/tritonrepoagent.h" +#include "triton/core/tritonserver.h" + +#undef _COMPILING_TRITONSERVER +#undef _COMPILING_TRITONBACKEND +#undef _COMPILING_TRITONREPOAGENT diff --git a/3rdparty/core-r22.12/src/tritonserver_stub.cc b/3rdparty/core-r22.12/src/tritonserver_stub.cc new file mode 100644 index 0000000000000000000000000000000000000000..402fb8313ac6ef31c0d7a1e4e008a53848fe0c36 --- /dev/null +++ b/3rdparty/core-r22.12/src/tritonserver_stub.cc @@ -0,0 +1,960 @@ +// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES{} LOSS OF USE, DATA, OR +// PROFITS{} OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#if defined(_MSC_VER) +#define TRITONAPI_DECLSPEC __declspec(dllexport) +#elif defined(__GNUC__) +#define TRITONAPI_DECLSPEC __attribute__((__visibility__("default"))) +#else +#define TRITONAPI_DECLSPEC +#endif + +extern "C" { +TRITONAPI_DECLSPEC void +TRITONSERVER_ApiVersion() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_DataTypeString() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_StringToDataType() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_DataTypeByteSize() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_MemoryTypeString() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ParameterTypeString() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ParameterNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ParameterBytesNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ParameterDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InstanceGroupKindString() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_LogIsEnabled() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_LogMessage() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ErrorNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ErrorDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ErrorCode() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ErrorCodeString() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ErrorMessage() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ResponseAllocatorNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ResponseAllocatorSetQueryFunction() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ResponseAllocatorSetBufferAttributesFunction() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ResponseAllocatorDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_MessageNewFromSerializedJson() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_MessageDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_MessageSerializeToJson() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_MetricsDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_MetricsFormatted() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceTraceLevelString() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceTraceActivityString() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceTraceNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceTraceTensorNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceTraceDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceTraceId() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceTraceParentId() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceTraceModelName() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceTraceModelVersion() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestId() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestSetId() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestFlags() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestSetFlags() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestCorrelationId() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestCorrelationIdString() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestSetCorrelationId() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestSetCorrelationIdString() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestPriority() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestSetPriority() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestTimeoutMicroseconds() +{ +} +TRITONAPI_DECLSPEC void 
+TRITONSERVER_InferenceRequestSetTimeoutMicroseconds() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestAddInput() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestAddRawInput() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestRemoveInput() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestRemoveAllInputs() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestAppendInputData() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestAppendInputDataWithHostPolicy() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestRemoveAllInputData() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestAddRequestedOutput() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestRemoveRequestedOutput() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestRemoveAllRequestedOutputs() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestSetReleaseCallback() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestSetResponseCallback() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceResponseDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceResponseError() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceResponseModel() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceResponseId() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceResponseParameterCount() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceResponseParameter() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceResponseOutputCount() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceResponseOutput() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceResponseOutputClassificationLabel() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetServerId() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetModelRepositoryPath() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetModelControlMode() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetStartupModel() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetStrictModelConfig() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetRateLimiterMode() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsAddRateLimiterResource() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetPinnedMemoryPoolByteSize() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetCudaMemoryPoolByteSize() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetResponseCacheByteSize() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetMinSupportedComputeCapability() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetExitOnError() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetStrictReadiness() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetExitTimeout() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetBufferManagerThreadCount() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetModelLoadThreadCount() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetLogFile() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetLogInfo() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetLogWarn() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetLogError() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetLogVerbose() +{ +} +TRITONAPI_DECLSPEC void 
+TRITONSERVER_ServerOptionsSetLogFormat() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetMetrics() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetGpuMetrics() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetCpuMetrics() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetMetricsInterval() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetBackendDirectory() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetRepoAgentDirectory() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetModelLoadDeviceLimit() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetBackendConfig() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerOptionsSetHostPolicy() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_InferenceRequestAppendInputDataWithBufferAttributes() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_BufferAttributesNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_BufferAttributesDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_BufferAttributesSetMemoryTypeId() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_BufferAttributesSetMemoryType() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_BufferAttributesSetCudaIpcHandle() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_BufferAttributesSetByteSize() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_BufferAttributesMemoryTypeId() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_BufferAttributesMemoryType() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_BufferAttributesCudaIpcHandle() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_BufferAttributesByteSize() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerStop() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerPollModelRepository() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerIsLive() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerIsReady() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerModelIsReady() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerModelBatchProperties() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerModelTransactionProperties() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerMetadata() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerModelMetadata() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerModelStatistics() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerModelConfig() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerModelIndex() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerRegisterModelRepository() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerUnregisterModelRepository() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerLoadModel() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerLoadModelWithParameters() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerUnloadModel() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerUnloadModelAndDependents() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerMetrics() +{ +} +TRITONAPI_DECLSPEC void +TRITONSERVER_ServerInferAsync() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ApiVersion() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_MemoryManagerAllocate() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_MemoryManagerFree() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_InputProperties() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_InputPropertiesForHostPolicy() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_InputBuffer() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_InputBufferForHostPolicy() +{ +} 
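// Illustrative usage sketch (editorial addition, not part of the vendored
// source): how a backend typically walks a request's inputs through the
// TRITONBACKEND_* entry points whose stubs appear above. Error checking on
// the returned TRITONSERVER_Error* values is elided for brevity.
#include "triton/core/tritonbackend.h"
#include <cstdio>

static void
DumpRequestInputs(TRITONBACKEND_Request* request)
{
  uint32_t input_count = 0;
  TRITONBACKEND_RequestInputCount(request, &input_count);

  for (uint32_t i = 0; i < input_count; ++i) {
    TRITONBACKEND_Input* input = nullptr;
    TRITONBACKEND_RequestInputByIndex(request, i, &input);

    const char* name = nullptr;
    TRITONSERVER_DataType datatype;
    const int64_t* shape = nullptr;
    uint32_t dims_count = 0;
    uint64_t byte_size = 0;
    uint32_t buffer_count = 0;
    TRITONBACKEND_InputProperties(
        input, &name, &datatype, &shape, &dims_count, &byte_size,
        &buffer_count);

    printf(
        "input %s: %u dims, %llu bytes in %u buffer(s)\n", name, dims_count,
        (unsigned long long)byte_size, buffer_count);

    // The tensor data may be split across several buffers, each possibly in
    // a different memory type (CPU, pinned, GPU).
    for (uint32_t b = 0; b < buffer_count; ++b) {
      const void* buffer = nullptr;
      uint64_t buffer_byte_size = 0;
      TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;
      int64_t memory_type_id = 0;
      TRITONBACKEND_InputBuffer(
          input, b, &buffer, &buffer_byte_size, &memory_type, &memory_type_id);
      // ... copy or consume 'buffer' according to 'memory_type' ...
    }
  }
}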
+TRITONAPI_DECLSPEC void +TRITONBACKEND_InputBufferAttributes() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_OutputBuffer() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_OutputBufferAttributes() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestId() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestCorrelationId() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestCorrelationIdString() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestFlags() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestInputCount() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestInputName() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestInput() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestInputByIndex() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestOutputCount() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestOutputName() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestOutputBufferProperties() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_RequestRelease() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ResponseFactoryNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ResponseFactoryDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ResponseFactorySendFlags() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ResponseNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ResponseNewFromFactory() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ResponseDelete() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ResponseSetStringParameter() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ResponseSetIntParameter() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ResponseSetBoolParameter() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ResponseOutput() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ResponseSend() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_StateNew() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_StateUpdate() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_StateBuffer() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_StateBufferAttributes() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_BackendName() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_BackendConfig() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_BackendExecutionPolicy() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_BackendSetExecutionPolicy() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_BackendArtifacts() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_BackendMemoryManager() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_BackendState() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_BackendSetState() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelName() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelVersion() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelRepository() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelConfig() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelAutoCompleteConfig() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelSetConfig() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelServer() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelBackend() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelState() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelSetState() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceName() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceKind() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceDeviceId() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceHostPolicy() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceIsPassive() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceProfileCount() 
+{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceProfileName() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceSecondaryDeviceCount() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceSecondaryDeviceProperties() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceModel() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceState() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceSetState() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceReportStatistics() +{ +} +TRITONAPI_DECLSPEC void +TRITONBACKEND_ModelInstanceReportBatchStatistics() +{ +} +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_ApiVersion() +{ +} +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_ModelRepositoryLocation() +{ +} + +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_ModelRepositoryLocationAcquire() +{ +} + +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_ModelRepositoryLocationRelease() +{ +} + +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_ModelRepositoryUpdate() +{ +} + +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_ModelParameterCount() +{ +} + +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_ModelParameter() +{ +} + +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_ModelConfig() +{ +} + +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_ModelState() +{ +} + +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_ModelSetState() +{ +} + +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_State() +{ +} + +TRITONAPI_DECLSPEC void +TRITONREPOAGENT_SetState() +{ +} + +TRITONAPI_DECLSPEC void +TRITONSERVER_MetricFamilyNew() +{ +} + +TRITONAPI_DECLSPEC void +TRITONSERVER_MetricFamilyDelete() +{ +} + +TRITONAPI_DECLSPEC void +TRITONSERVER_MetricNew() +{ +} + +TRITONAPI_DECLSPEC void +TRITONSERVER_MetricDelete() +{ +} + +TRITONAPI_DECLSPEC void +TRITONSERVER_MetricValue() +{ +} + +TRITONAPI_DECLSPEC void +TRITONSERVER_MetricIncrement() +{ +} + +TRITONAPI_DECLSPEC void +TRITONSERVER_MetricSet() +{ +} + +TRITONAPI_DECLSPEC void +TRITONSERVER_GetMetricKind() +{ +} + +TRITONAPI_DECLSPEC void +TRITONBACKEND_BackendAttributeAddPreferredInstanceGroup() +{ +} + +} /* extern "C" */ diff --git a/3rdparty/googletest-1.13.0/.clang-format b/3rdparty/googletest-1.13.0/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..5b9bfe6d224232981ada90cee232c716afbdf09d --- /dev/null +++ b/3rdparty/googletest-1.13.0/.clang-format @@ -0,0 +1,4 @@ +# Run manually to reformat a file: +# clang-format -i --style=file +Language: Cpp +BasedOnStyle: Google diff --git a/3rdparty/googletest-1.13.0/.github/ISSUE_TEMPLATE/00-bug_report.yml b/3rdparty/googletest-1.13.0/.github/ISSUE_TEMPLATE/00-bug_report.yml new file mode 100644 index 0000000000000000000000000000000000000000..586779ad2d618299eff5b68f9f8a3da6013934b8 --- /dev/null +++ b/3rdparty/googletest-1.13.0/.github/ISSUE_TEMPLATE/00-bug_report.yml @@ -0,0 +1,53 @@ +name: Bug Report +description: Let us know that something does not work as expected. +title: "[Bug]: Please title this bug report" +body: + - type: textarea + id: what-happened + attributes: + label: Describe the issue + description: What happened, and what did you expect to happen? + validations: + required: true + - type: textarea + id: steps + attributes: + label: Steps to reproduce the problem + description: It is important that we are able to reproduce the problem that you are experiencing. Please provide all code and relevant steps to reproduce the problem, including your `BUILD`/`CMakeLists.txt` file and build commands. 
Links to a GitHub branch or [godbolt.org](https://godbolt.org/) that demonstrate the problem are also helpful. + validations: + required: true + - type: textarea + id: version + attributes: + label: What version of GoogleTest are you using? + description: Please include the output of `git rev-parse HEAD` or the GoogleTest release version number that you are using. + validations: + required: true + - type: textarea + id: os + attributes: + label: What operating system and version are you using? + description: If you are using a Linux distribution please include the name and version of the distribution as well. + validations: + required: true + - type: textarea + id: compiler + attributes: + label: What compiler and version are you using? + description: Please include the output of `gcc -v` or `clang -v`, or the equivalent for your compiler. + validations: + required: true + - type: textarea + id: buildsystem + attributes: + label: What build system are you using? + description: Please include the output of `bazel --version` or `cmake --version`, or the equivalent for your build system. + validations: + required: true + - type: textarea + id: additional + attributes: + label: Additional context + description: Add any other context about the problem here. + validations: + required: false diff --git a/3rdparty/googletest-1.13.0/.github/ISSUE_TEMPLATE/10-feature_request.yml b/3rdparty/googletest-1.13.0/.github/ISSUE_TEMPLATE/10-feature_request.yml new file mode 100644 index 0000000000000000000000000000000000000000..91ad0417702e1971fab9e8d57df9a3efbde9240d --- /dev/null +++ b/3rdparty/googletest-1.13.0/.github/ISSUE_TEMPLATE/10-feature_request.yml @@ -0,0 +1,33 @@ +name: Feature request +description: Propose a new feature. +title: "[FR]: Please title this feature request" +labels: "enhancement" +body: + - type: textarea + id: version + attributes: + label: Does the feature exist in the most recent commit? + description: We recommend using the latest commit from GitHub in your projects. + validations: + required: true + - type: textarea + id: why + attributes: + label: Why do we need this feature? + description: Ideally, explain why a combination of existing features cannot be used instead. + validations: + required: true + - type: textarea + id: proposal + attributes: + label: Describe the proposal. + description: Include a detailed description of the feature, with usage examples. + validations: + required: true + - type: textarea + id: platform + attributes: + label: Is the feature specific to an operating system, compiler, or build system version? + description: If it is, please specify which versions. + validations: + required: true diff --git a/3rdparty/googletest-1.13.0/.github/ISSUE_TEMPLATE/config.yml b/3rdparty/googletest-1.13.0/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..65170d10a78231455bed85e929bc008927445644 --- /dev/null +++ b/3rdparty/googletest-1.13.0/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: Get Help + url: https://github.com/google/googletest/discussions + about: Please ask and answer questions here. 
diff --git a/3rdparty/googletest-1.13.0/.github/workflows/gtest-ci.yml b/3rdparty/googletest-1.13.0/.github/workflows/gtest-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..03a8cc5e287b47bacc3b9ae7dc0d5b966cf4debe --- /dev/null +++ b/3rdparty/googletest-1.13.0/.github/workflows/gtest-ci.yml @@ -0,0 +1,43 @@ +name: ci + +on: + push: + pull_request: + +env: + BAZEL_CXXOPTS: -std=c++14 + +jobs: + Linux: + runs-on: ubuntu-latest + steps: + + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Tests + run: bazel test --cxxopt=-std=c++14 --features=external_include_paths --test_output=errors ... + + macOS: + runs-on: macos-latest + steps: + + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Tests + run: bazel test --cxxopt=-std=c++14 --features=external_include_paths --test_output=errors ... + + + Windows: + runs-on: windows-latest + steps: + + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Tests + run: bazel test --cxxopt=/std:c++14 --features=external_include_paths --test_output=errors ... diff --git a/3rdparty/googletest-1.13.0/.gitignore b/3rdparty/googletest-1.13.0/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f08cb72a33cd199478f41be1bd487f916330472c --- /dev/null +++ b/3rdparty/googletest-1.13.0/.gitignore @@ -0,0 +1,84 @@ +# Ignore CI build directory +build/ +xcuserdata +cmake-build-debug/ +.idea/ +bazel-bin +bazel-genfiles +bazel-googletest +bazel-out +bazel-testlogs +# python +*.pyc + +# Visual Studio files +.vs +*.sdf +*.opensdf +*.VC.opendb +*.suo +*.user +_ReSharper.Caches/ +Win32-Debug/ +Win32-Release/ +x64-Debug/ +x64-Release/ + +# Ignore autoconf / automake files +Makefile.in +aclocal.m4 +configure +build-aux/ +autom4te.cache/ +googletest/m4/libtool.m4 +googletest/m4/ltoptions.m4 +googletest/m4/ltsugar.m4 +googletest/m4/ltversion.m4 +googletest/m4/lt~obsolete.m4 +googlemock/m4 + +# Ignore generated directories. +googlemock/fused-src/ +googletest/fused-src/ + +# macOS files +.DS_Store +googletest/.DS_Store +googletest/xcode/.DS_Store + +# Ignore cmake generated directories and files. +CMakeFiles +CTestTestfile.cmake +Makefile +cmake_install.cmake +googlemock/CMakeFiles +googlemock/CTestTestfile.cmake +googlemock/Makefile +googlemock/cmake_install.cmake +googlemock/gtest +/bin +/googlemock/gmock.dir +/googlemock/gmock_main.dir +/googlemock/RUN_TESTS.vcxproj.filters +/googlemock/RUN_TESTS.vcxproj +/googlemock/INSTALL.vcxproj.filters +/googlemock/INSTALL.vcxproj +/googlemock/gmock_main.vcxproj.filters +/googlemock/gmock_main.vcxproj +/googlemock/gmock.vcxproj.filters +/googlemock/gmock.vcxproj +/googlemock/gmock.sln +/googlemock/ALL_BUILD.vcxproj.filters +/googlemock/ALL_BUILD.vcxproj +/lib +/Win32 +/ZERO_CHECK.vcxproj.filters +/ZERO_CHECK.vcxproj +/RUN_TESTS.vcxproj.filters +/RUN_TESTS.vcxproj +/INSTALL.vcxproj.filters +/INSTALL.vcxproj +/googletest-distribution.sln +/CMakeCache.txt +/ALL_BUILD.vcxproj.filters +/ALL_BUILD.vcxproj diff --git a/3rdparty/googletest-1.13.0/BUILD.bazel b/3rdparty/googletest-1.13.0/BUILD.bazel new file mode 100644 index 0000000000000000000000000000000000000000..ac62251e10172614d93b385777e52defc187cac4 --- /dev/null +++ b/3rdparty/googletest-1.13.0/BUILD.bazel @@ -0,0 +1,218 @@ +# Copyright 2017 Google Inc. +# All Rights Reserved. 
+# +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Bazel Build for Google C++ Testing Framework(Google Test) + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +config_setting( + name = "qnx", + constraint_values = ["@platforms//os:qnx"], +) + +config_setting( + name = "windows", + constraint_values = ["@platforms//os:windows"], +) + +config_setting( + name = "freebsd", + constraint_values = ["@platforms//os:freebsd"], +) + +config_setting( + name = "openbsd", + constraint_values = ["@platforms//os:openbsd"], +) + +config_setting( + name = "msvc_compiler", + flag_values = { + "@bazel_tools//tools/cpp:compiler": "msvc-cl", + }, + visibility = [":__subpackages__"], +) + +config_setting( + name = "has_absl", + values = {"define": "absl=1"}, +) + +# Library that defines the FRIEND_TEST macro. 
+cc_library( + name = "gtest_prod", + hdrs = ["googletest/include/gtest/gtest_prod.h"], + includes = ["googletest/include"], +) + +# Google Test including Google Mock +cc_library( + name = "gtest", + srcs = glob( + include = [ + "googletest/src/*.cc", + "googletest/src/*.h", + "googletest/include/gtest/**/*.h", + "googlemock/src/*.cc", + "googlemock/include/gmock/**/*.h", + ], + exclude = [ + "googletest/src/gtest-all.cc", + "googletest/src/gtest_main.cc", + "googlemock/src/gmock-all.cc", + "googlemock/src/gmock_main.cc", + ], + ), + hdrs = glob([ + "googletest/include/gtest/*.h", + "googlemock/include/gmock/*.h", + ]), + copts = select({ + ":qnx": [], + ":windows": [], + "//conditions:default": ["-pthread"], + }), + defines = select({ + ":has_absl": ["GTEST_HAS_ABSL=1"], + "//conditions:default": [], + }), + features = select({ + ":windows": ["windows_export_all_symbols"], + "//conditions:default": [], + }), + includes = [ + "googlemock", + "googlemock/include", + "googletest", + "googletest/include", + ], + linkopts = select({ + ":qnx": ["-lregex"], + ":windows": [], + ":freebsd": [ + "-lm", + "-pthread", + ], + ":openbsd": [ + "-lm", + "-pthread", + ], + "//conditions:default": ["-pthread"], + }), + deps = select({ + ":has_absl": [ + "@com_google_absl//absl/debugging:failure_signal_handler", + "@com_google_absl//absl/debugging:stacktrace", + "@com_google_absl//absl/debugging:symbolize", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/flags:reflection", + "@com_google_absl//absl/flags:usage", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:any", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", + "@com_googlesource_code_re2//:re2", + ], + "//conditions:default": [], + }), +) + +cc_library( + name = "gtest_main", + srcs = ["googlemock/src/gmock_main.cc"], + features = select({ + ":windows": ["windows_export_all_symbols"], + "//conditions:default": [], + }), + deps = [":gtest"], +) + +# The following rules build samples of how to use gTest. 
+cc_library( + name = "gtest_sample_lib", + srcs = [ + "googletest/samples/sample1.cc", + "googletest/samples/sample2.cc", + "googletest/samples/sample4.cc", + ], + hdrs = [ + "googletest/samples/prime_tables.h", + "googletest/samples/sample1.h", + "googletest/samples/sample2.h", + "googletest/samples/sample3-inl.h", + "googletest/samples/sample4.h", + ], + features = select({ + ":windows": ["windows_export_all_symbols"], + "//conditions:default": [], + }), +) + +cc_test( + name = "gtest_samples", + size = "small", + # All Samples except: + # sample9 (main) + # sample10 (main and takes a command line option and needs to be separate) + srcs = [ + "googletest/samples/sample1_unittest.cc", + "googletest/samples/sample2_unittest.cc", + "googletest/samples/sample3_unittest.cc", + "googletest/samples/sample4_unittest.cc", + "googletest/samples/sample5_unittest.cc", + "googletest/samples/sample6_unittest.cc", + "googletest/samples/sample7_unittest.cc", + "googletest/samples/sample8_unittest.cc", + ], + linkstatic = 0, + deps = [ + "gtest_sample_lib", + ":gtest_main", + ], +) + +cc_test( + name = "sample9_unittest", + size = "small", + srcs = ["googletest/samples/sample9_unittest.cc"], + deps = [":gtest"], +) + +cc_test( + name = "sample10_unittest", + size = "small", + srcs = ["googletest/samples/sample10_unittest.cc"], + deps = [":gtest"], +) diff --git a/3rdparty/googletest-1.13.0/CONTRIBUTING.md b/3rdparty/googletest-1.13.0/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..de14c8159b17e367dd0098c7202e13439a06ff89 --- /dev/null +++ b/3rdparty/googletest-1.13.0/CONTRIBUTING.md @@ -0,0 +1,131 @@ +# How to become a contributor and submit your own code + +## Contributor License Agreements + +We'd love to accept your patches! Before we can take them, we have to jump a +couple of legal hurdles. + +Please fill out either the individual or corporate Contributor License Agreement +(CLA). + +* If you are an individual writing original source code and you're sure you + own the intellectual property, then you'll need to sign an + [individual CLA](https://developers.google.com/open-source/cla/individual). +* If you work for a company that wants to allow you to contribute your work, + then you'll need to sign a + [corporate CLA](https://developers.google.com/open-source/cla/corporate). + +Follow either of the two links above to access the appropriate CLA and +instructions for how to sign and return it. Once we receive it, we'll be able to +accept your pull requests. + +## Are you a Googler? + +If you are a Googler, please make an attempt to submit an internal contribution +rather than a GitHub Pull Request. If you are not able to submit internally, a +PR is acceptable as an alternative. + +## Contributing A Patch + +1. Submit an issue describing your proposed change to the + [issue tracker](https://github.com/google/googletest/issues). +2. Please don't mix more than one logical change per submittal, because it + makes the history hard to follow. If you want to make a change that doesn't + have a corresponding issue in the issue tracker, please create one. +3. Also, coordinate with team members that are listed on the issue in question. + This ensures that work isn't being duplicated and communicating your plan + early also generally leads to better patches. +4. If your proposed change is accepted, and you haven't already done so, sign a + Contributor License Agreement + ([see details above](#contributor-license-agreements)). +5. 
Fork the desired repo, develop and test your code changes. +6. Ensure that your code adheres to the existing style in the sample to which + you are contributing. +7. Ensure that your code has an appropriate set of unit tests which all pass. +8. Submit a pull request. + +## The Google Test and Google Mock Communities + +The Google Test community exists primarily through the +[discussion group](http://groups.google.com/group/googletestframework) and the +GitHub repository. Likewise, the Google Mock community exists primarily through +their own [discussion group](http://groups.google.com/group/googlemock). You are +definitely encouraged to contribute to the discussion and you can also help us +to keep the effectiveness of the group high by following and promoting the +guidelines listed here. + +### Please Be Friendly + +Showing courtesy and respect to others is a vital part of the Google culture, +and we strongly encourage everyone participating in Google Test development to +join us in accepting nothing less. Of course, being courteous is not the same as +failing to constructively disagree with each other, but it does mean that we +should be respectful of each other when enumerating the 42 technical reasons +that a particular proposal may not be the best choice. There's never a reason to +be antagonistic or dismissive toward anyone who is sincerely trying to +contribute to a discussion. + +Sure, C++ testing is serious business and all that, but it's also a lot of fun. +Let's keep it that way. Let's strive to be one of the friendliest communities in +all of open source. + +As always, discuss Google Test in the official GoogleTest discussion group. You +don't have to actually submit code in order to sign up. Your participation +itself is a valuable contribution. + +## Style + +To keep the source consistent, readable, diffable and easy to merge, we use a +fairly rigid coding style, as defined by the +[google-styleguide](https://github.com/google/styleguide) project. All patches +will be expected to conform to the style outlined +[here](https://google.github.io/styleguide/cppguide.html). Use +[.clang-format](https://github.com/google/googletest/blob/main/.clang-format) to +check your formatting. + +## Requirements for Contributors + +If you plan to contribute a patch, you need to build Google Test, Google Mock, +and their own tests from a git checkout, which has further requirements: + +* [Python](https://www.python.org/) v2.3 or newer (for running some of the + tests and re-generating certain source files from templates) +* [CMake](https://cmake.org/) v2.8.12 or newer + +## Developing Google Test and Google Mock + +This section discusses how to make your own changes to the Google Test project. + +### Testing Google Test and Google Mock Themselves + +To make sure your changes work as intended and don't break existing +functionality, you'll want to compile and run Google Test and GoogleMock's own +tests. For that you can use CMake: + + mkdir mybuild + cd mybuild + cmake -Dgtest_build_tests=ON -Dgmock_build_tests=ON ${GTEST_REPO_DIR} + +To choose between building only Google Test or Google Mock, you may modify your +cmake command to be one of each + + cmake -Dgtest_build_tests=ON ${GTEST_DIR} # sets up Google Test tests + cmake -Dgmock_build_tests=ON ${GMOCK_DIR} # sets up Google Mock tests + +Make sure you have Python installed, as some of Google Test's tests are written +in Python. 
If the cmake command complains about not being able to find Python +(`Could NOT find PythonInterp (missing: PYTHON_EXECUTABLE)`), try telling it +explicitly where your Python executable can be found: + + cmake -DPYTHON_EXECUTABLE=path/to/python ... + +Next, you can build Google Test and / or Google Mock and all desired tests. On +\*nix, this is usually done by + + make + +To run the tests, do + + make test + +All tests should pass. diff --git a/3rdparty/googletest-1.13.0/CONTRIBUTORS b/3rdparty/googletest-1.13.0/CONTRIBUTORS new file mode 100644 index 0000000000000000000000000000000000000000..77397a5b53fea5352f8af38bdb1c4a0ce0e30d66 --- /dev/null +++ b/3rdparty/googletest-1.13.0/CONTRIBUTORS @@ -0,0 +1,65 @@ +# This file contains a list of people who've made non-trivial +# contribution to the Google C++ Testing Framework project. People +# who commit code to the project are encouraged to add their names +# here. Please keep the list sorted by first names. + +Ajay Joshi +Balázs Dán +Benoit Sigoure +Bharat Mediratta +Bogdan Piloca +Chandler Carruth +Chris Prince +Chris Taylor +Dan Egnor +Dave MacLachlan +David Anderson +Dean Sturtevant +Eric Roman +Gene Volovich +Hady Zalek +Hal Burch +Jeffrey Yasskin +Jim Keller +Joe Walnes +Jon Wray +Jói Sigurðsson +Keir Mierle +Keith Ray +Kenton Varda +Kostya Serebryany +Krystian Kuzniarek +Lev Makhlis +Manuel Klimek +Mario Tanev +Mark Paskin +Markus Heule +Martijn Vels +Matthew Simmons +Mika Raento +Mike Bland +Miklós Fazekas +Neal Norwitz +Nermin Ozkiranartli +Owen Carlsen +Paneendra Ba +Pasi Valminen +Patrick Hanna +Patrick Riley +Paul Menage +Peter Kaminski +Piotr Kaminski +Preston Jackson +Rainer Klaffenboeck +Russ Cox +Russ Rufer +Sean Mcafee +Sigurður Ásgeirsson +Sverre Sundsdal +Szymon Sobik +Takeshi Yoshino +Tracy Bialik +Vadim Berman +Vlad Losev +Wolfgang Klier +Zhanyong Wan diff --git a/3rdparty/googletest-1.13.0/LICENSE b/3rdparty/googletest-1.13.0/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..1941a11f8ce94389160b458927a29ba217542818 --- /dev/null +++ b/3rdparty/googletest-1.13.0/LICENSE @@ -0,0 +1,28 @@ +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/3rdparty/googletest-1.13.0/README.md b/3rdparty/googletest-1.13.0/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cd89abb2d2daea7d0feef40f2496106a9bfd31f4 --- /dev/null +++ b/3rdparty/googletest-1.13.0/README.md @@ -0,0 +1,117 @@ +# GoogleTest + +### Announcements + +#### Live at Head + +GoogleTest now follows the +[Abseil Live at Head philosophy](https://abseil.io/about/philosophy#upgrade-support). +We recommend +[updating to the latest commit in the `main` branch as often as possible](https://github.com/abseil/abseil-cpp/blob/master/FAQ.md#what-is-live-at-head-and-how-do-i-do-it). + +#### Documentation Updates + +Our documentation is now live on GitHub Pages at +https://google.github.io/googletest/. We recommend browsing the documentation on +GitHub Pages rather than directly in the repository. + +#### Release 1.12.1 + +[Release 1.12.1](https://github.com/google/googletest/releases/tag/release-1.12.1) +is now available. + +The 1.12.x branch will be the last to support C++11. Future releases will +require at least C++14. + +#### Coming Soon + +* We are planning to take a dependency on + [Abseil](https://github.com/abseil/abseil-cpp). +* More documentation improvements are planned. + +## Welcome to **GoogleTest**, Google's C++ test framework! + +This repository is a merger of the formerly separate GoogleTest and GoogleMock +projects. These were so closely related that it makes sense to maintain and +release them together. + +### Getting Started + +See the [GoogleTest User's Guide](https://google.github.io/googletest/) for +documentation. We recommend starting with the +[GoogleTest Primer](https://google.github.io/googletest/primer.html). + +More information about building GoogleTest can be found at +[googletest/README.md](googletest/README.md). + +## Features + +* An [xUnit](https://en.wikipedia.org/wiki/XUnit) test framework. +* Test discovery. +* A rich set of assertions. +* User-defined assertions. +* Death tests. +* Fatal and non-fatal failures. +* Value-parameterized tests. +* Type-parameterized tests. +* Various options for running the tests. +* XML test report generation. + +## Supported Platforms + +GoogleTest follows Google's +[Foundational C++ Support Policy](https://opensource.google/documentation/policies/cplusplus-support). +See +[this table](https://github.com/google/oss-policies-info/blob/main/foundational-cxx-support-matrix.md) +for a list of currently supported versions compilers, platforms, and build +tools. + +## Who Is Using GoogleTest? + +In addition to many internal projects at Google, GoogleTest is also used by the +following notable projects: + +* The [Chromium projects](http://www.chromium.org/) (behind the Chrome browser + and Chrome OS). +* The [LLVM](http://llvm.org/) compiler. +* [Protocol Buffers](https://github.com/google/protobuf), Google's data + interchange format. +* The [OpenCV](http://opencv.org/) computer vision library. 
+ +## Related Open Source Projects + +[GTest Runner](https://github.com/nholthaus/gtest-runner) is a Qt5 based +automated test-runner and Graphical User Interface with powerful features for +Windows and Linux platforms. + +[GoogleTest UI](https://github.com/ospector/gtest-gbar) is a test runner that +runs your test binary, allows you to track its progress via a progress bar, and +displays a list of test failures. Clicking on one shows failure text. GoogleTest +UI is written in C#. + +[GTest TAP Listener](https://github.com/kinow/gtest-tap-listener) is an event +listener for GoogleTest that implements the +[TAP protocol](https://en.wikipedia.org/wiki/Test_Anything_Protocol) for test +result output. If your test runner understands TAP, you may find it useful. + +[gtest-parallel](https://github.com/google/gtest-parallel) is a test runner that +runs tests from your binary in parallel to provide significant speed-up. + +[GoogleTest Adapter](https://marketplace.visualstudio.com/items?itemName=DavidSchuldenfrei.gtest-adapter) +is a VS Code extension allowing to view GoogleTest in a tree view and run/debug +your tests. + +[C++ TestMate](https://github.com/matepek/vscode-catch2-test-adapter) is a VS +Code extension allowing to view GoogleTest in a tree view and run/debug your +tests. + +[Cornichon](https://pypi.org/project/cornichon/) is a small Gherkin DSL parser +that generates stub code for GoogleTest. + +## Contributing Changes + +Please read +[`CONTRIBUTING.md`](https://github.com/google/googletest/blob/main/CONTRIBUTING.md) +for details on how to contribute to this project. + +Happy testing! diff --git a/3rdparty/googletest-1.13.0/WORKSPACE b/3rdparty/googletest-1.13.0/WORKSPACE new file mode 100644 index 0000000000000000000000000000000000000000..0f10a6a9a8691036391e8acd814e500a33f5a7bf --- /dev/null +++ b/3rdparty/googletest-1.13.0/WORKSPACE @@ -0,0 +1,40 @@ +workspace(name = "com_google_googletest") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "com_google_absl", # 2023-01-10T21:08:25Z + sha256 = "f9a4e749f42c386a32a90fddf0e2913ed408d10c42f7f33ccf4c59ac4f0d1d05", + strip_prefix = "abseil-cpp-52835439ca90d86b27bf8cd1708296e95604d724", + urls = ["https://github.com/abseil/abseil-cpp/archive/52835439ca90d86b27bf8cd1708296e95604d724.zip"], +) + +# Note this must use a commit from the `abseil` branch of the RE2 project. 
+# https://github.com/google/re2/tree/abseil +http_archive( + name = "com_googlesource_code_re2", # 2022-12-21T14:29:10Z + sha256 = "b9ce3a51beebb38534d11d40f8928d40509b9e18a735f6a4a97ad3d014c87cb5", + strip_prefix = "re2-d0b1f8f2ecc2ea74956c7608b6f915175314ff0e", + urls = ["https://github.com/google/re2/archive/d0b1f8f2ecc2ea74956c7608b6f915175314ff0e.zip"], +) + +http_archive( + name = "rules_python", # 2023-01-10T22:00:51Z + sha256 = "5de54486a60ad8948dabe49605bb1c08053e04001a431ab3e96745b4d97a4419", + strip_prefix = "rules_python-70cce26432187a60b4e950118791385e6fb3c26f", + urls = ["https://github.com/bazelbuild/rules_python/archive/70cce26432187a60b4e950118791385e6fb3c26f.zip"], +) + +http_archive( + name = "bazel_skylib", # 2022-11-16T18:29:32Z + sha256 = "a22290c26d29d3ecca286466f7f295ac6cbe32c0a9da3a91176a90e0725e3649", + strip_prefix = "bazel-skylib-5bfcb1a684550626ce138fe0fe8f5f702b3764c3", + urls = ["https://github.com/bazelbuild/bazel-skylib/archive/5bfcb1a684550626ce138fe0fe8f5f702b3764c3.zip"], +) + +http_archive( + name = "platforms", # 2022-11-09T19:18:22Z + sha256 = "b4a3b45dc4202e2b3e34e3bc49d2b5b37295fc23ea58d88fb9e01f3642ad9b55", + strip_prefix = "platforms-3fbc687756043fb58a407c2ea8c944bc2fe1d922", + urls = ["https://github.com/bazelbuild/platforms/archive/3fbc687756043fb58a407c2ea8c944bc2fe1d922.zip"], +) diff --git a/3rdparty/googletest-1.13.0/ci/linux-presubmit.sh b/3rdparty/googletest-1.13.0/ci/linux-presubmit.sh new file mode 100644 index 0000000000000000000000000000000000000000..4eb5bbe4a1dac8923fc0aa2f2d49cc38f258ec4c --- /dev/null +++ b/3rdparty/googletest-1.13.0/ci/linux-presubmit.sh @@ -0,0 +1,134 @@ +#!/bin/bash +# +# Copyright 2020, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +set -euox pipefail + +readonly LINUX_LATEST_CONTAINER="gcr.io/google.com/absl-177019/linux_hybrid-latest:20220217" +readonly LINUX_GCC_FLOOR_CONTAINER="gcr.io/google.com/absl-177019/linux_gcc-floor:20220621" + +if [[ -z ${GTEST_ROOT:-} ]]; then + GTEST_ROOT="$(realpath $(dirname ${0})/..)" +fi + +if [[ -z ${STD:-} ]]; then + STD="c++14 c++17 c++20" +fi + +# Test the CMake build +for cc in /usr/local/bin/gcc /opt/llvm/clang/bin/clang; do + for cmake_off_on in OFF ON; do + time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --tmpfs="/build:exec" \ + --workdir="/build" \ + --rm \ + --env="CC=${cc}" \ + --env="CXX_FLAGS=\"-Werror -Wdeprecated\"" \ + ${LINUX_LATEST_CONTAINER} \ + /bin/bash -c " + cmake /src \ + -DCMAKE_CXX_STANDARD=14 \ + -Dgtest_build_samples=ON \ + -Dgtest_build_tests=ON \ + -Dgmock_build_tests=ON \ + -Dcxx_no_exception=${cmake_off_on} \ + -Dcxx_no_rtti=${cmake_off_on} && \ + make -j$(nproc) && \ + ctest -j$(nproc) --output-on-failure" + done +done + +# Do one test with an older version of GCC +time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --workdir="/src" \ + --rm \ + --env="CC=/usr/local/bin/gcc" \ + --env="BAZEL_CXXOPTS=-std=c++14" \ + ${LINUX_GCC_FLOOR_CONTAINER} \ + /usr/local/bin/bazel test ... \ + --copt="-Wall" \ + --copt="-Werror" \ + --copt="-Wuninitialized" \ + --copt="-Wno-error=pragmas" \ + --distdir="/bazel-distdir" \ + --features=external_include_paths \ + --keep_going \ + --show_timestamps \ + --test_output=errors + +# Test GCC +for std in ${STD}; do + for absl in 0 1; do + time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --workdir="/src" \ + --rm \ + --env="CC=/usr/local/bin/gcc" \ + --env="BAZEL_CXXOPTS=-std=${std}" \ + ${LINUX_LATEST_CONTAINER} \ + /usr/local/bin/bazel test ... \ + --copt="-Wall" \ + --copt="-Werror" \ + --copt="-Wuninitialized" \ + --define="absl=${absl}" \ + --distdir="/bazel-distdir" \ + --features=external_include_paths \ + --keep_going \ + --show_timestamps \ + --test_output=errors + done +done + +# Test Clang +for std in ${STD}; do + for absl in 0 1; do + time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --workdir="/src" \ + --rm \ + --env="CC=/opt/llvm/clang/bin/clang" \ + --env="BAZEL_CXXOPTS=-std=${std}" \ + ${LINUX_LATEST_CONTAINER} \ + /usr/local/bin/bazel test ... \ + --copt="--gcc-toolchain=/usr/local" \ + --copt="-Wall" \ + --copt="-Werror" \ + --copt="-Wuninitialized" \ + --define="absl=${absl}" \ + --distdir="/bazel-distdir" \ + --features=external_include_paths \ + --keep_going \ + --linkopt="--gcc-toolchain=/usr/local" \ + --show_timestamps \ + --test_output=errors + done +done diff --git a/3rdparty/googletest-1.13.0/ci/macos-presubmit.sh b/3rdparty/googletest-1.13.0/ci/macos-presubmit.sh new file mode 100644 index 0000000000000000000000000000000000000000..8f35df58d2baa1b278f468608deb517d05150a3a --- /dev/null +++ b/3rdparty/googletest-1.13.0/ci/macos-presubmit.sh @@ -0,0 +1,75 @@ +#!/bin/bash +# +# Copyright 2020, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. 
nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set -euox pipefail + +if [[ -z ${GTEST_ROOT:-} ]]; then + GTEST_ROOT="$(realpath $(dirname ${0})/..)" +fi + +# Test the CMake build +for cmake_off_on in OFF ON; do + BUILD_DIR=$(mktemp -d build_dir.XXXXXXXX) + cd ${BUILD_DIR} + time cmake ${GTEST_ROOT} \ + -DCMAKE_CXX_STANDARD=14 \ + -Dgtest_build_samples=ON \ + -Dgtest_build_tests=ON \ + -Dgmock_build_tests=ON \ + -Dcxx_no_exception=${cmake_off_on} \ + -Dcxx_no_rtti=${cmake_off_on} + time make + time ctest -j$(nproc) --output-on-failure +done + +# Test the Bazel build + +# If we are running on Kokoro, check for a versioned Bazel binary. +KOKORO_GFILE_BAZEL_BIN="bazel-5.1.1-darwin-x86_64" +if [[ ${KOKORO_GFILE_DIR:-} ]] && [[ -f ${KOKORO_GFILE_DIR}/${KOKORO_GFILE_BAZEL_BIN} ]]; then + BAZEL_BIN="${KOKORO_GFILE_DIR}/${KOKORO_GFILE_BAZEL_BIN}" + chmod +x ${BAZEL_BIN} +else + BAZEL_BIN="bazel" +fi + +cd ${GTEST_ROOT} +for absl in 0 1; do + ${BAZEL_BIN} test ... \ + --copt="-Wall" \ + --copt="-Werror" \ + --cxxopt="-std=c++14" \ + --define="absl=${absl}" \ + --features=external_include_paths \ + --keep_going \ + --show_timestamps \ + --test_output=errors +done diff --git a/3rdparty/googletest-1.13.0/ci/windows-presubmit.bat b/3rdparty/googletest-1.13.0/ci/windows-presubmit.bat new file mode 100644 index 0000000000000000000000000000000000000000..8668ff3594b9fb31e5baf8c1f3c6de7946377b62 --- /dev/null +++ b/3rdparty/googletest-1.13.0/ci/windows-presubmit.bat @@ -0,0 +1,56 @@ +SETLOCAL ENABLEDELAYEDEXPANSION + +SET BAZEL_EXE=%KOKORO_GFILE_DIR%\bazel-5.1.1-windows-x86_64.exe + +SET PATH=C:\Python37;%PATH% +SET BAZEL_PYTHON=C:\python37\python.exe +SET BAZEL_SH=C:\tools\msys64\usr\bin\bash.exe +SET CMAKE_BIN="C:\Program Files\CMake\bin\cmake.exe" +SET CTEST_BIN="C:\Program Files\CMake\bin\ctest.exe" +SET CTEST_OUTPUT_ON_FAILURE=1 + +IF EXIST git\googletest ( + CD git\googletest +) ELSE IF EXIST github\googletest ( + CD github\googletest +) + +IF %errorlevel% neq 0 EXIT /B 1 + +:: ---------------------------------------------------------------------------- +:: CMake Visual Studio 15 2017 Win64 +MKDIR cmake_msvc2017 +CD cmake_msvc2017 + +%CMAKE_BIN% .. ^ + -G "Visual Studio 15 2017 Win64" ^ + -DPYTHON_EXECUTABLE:FILEPATH=c:\python37\python.exe ^ + -DPYTHON_INCLUDE_DIR:PATH=c:\python37\include ^ + -DPYTHON_LIBRARY:FILEPATH=c:\python37\lib\site-packages\pip ^ + -Dgtest_build_samples=ON ^ + -Dgtest_build_tests=ON ^ + -Dgmock_build_tests=ON +IF %errorlevel% neq 0 EXIT /B 1 + +%CMAKE_BIN% --build . 
--target ALL_BUILD --config Debug -- -maxcpucount +IF %errorlevel% neq 0 EXIT /B 1 + +%CTEST_BIN% -C Debug --timeout 600 +IF %errorlevel% neq 0 EXIT /B 1 + +CD .. +RMDIR /S /Q cmake_msvc2017 + +:: ---------------------------------------------------------------------------- +:: Bazel Visual Studio 15 2017 Win64 + +SET BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2017\BuildTools\VC +%BAZEL_EXE% test ... ^ + --compilation_mode=dbg ^ + --copt=/std:c++14 ^ + --copt=/WX ^ + --features=external_include_paths ^ + --keep_going ^ + --test_output=errors ^ + --test_tag_filters=-no_test_msvc2017 +IF %errorlevel% neq 0 EXIT /B 1 diff --git a/3rdparty/googletest-1.13.0/docs/_config.yml b/3rdparty/googletest-1.13.0/docs/_config.yml new file mode 100644 index 0000000000000000000000000000000000000000..d12867eab6b6872489002b56cf5c4115388fb1aa --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/_config.yml @@ -0,0 +1 @@ +title: GoogleTest diff --git a/3rdparty/googletest-1.13.0/docs/_data/navigation.yml b/3rdparty/googletest-1.13.0/docs/_data/navigation.yml new file mode 100644 index 0000000000000000000000000000000000000000..9f3332708eac165cd1fe2516f2b2cb855c5a32ef --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/_data/navigation.yml @@ -0,0 +1,43 @@ +nav: +- section: "Get Started" + items: + - title: "Supported Platforms" + url: "/platforms.html" + - title: "Quickstart: Bazel" + url: "/quickstart-bazel.html" + - title: "Quickstart: CMake" + url: "/quickstart-cmake.html" +- section: "Guides" + items: + - title: "GoogleTest Primer" + url: "/primer.html" + - title: "Advanced Topics" + url: "/advanced.html" + - title: "Mocking for Dummies" + url: "/gmock_for_dummies.html" + - title: "Mocking Cookbook" + url: "/gmock_cook_book.html" + - title: "Mocking Cheat Sheet" + url: "/gmock_cheat_sheet.html" +- section: "References" + items: + - title: "Testing Reference" + url: "/reference/testing.html" + - title: "Mocking Reference" + url: "/reference/mocking.html" + - title: "Assertions" + url: "/reference/assertions.html" + - title: "Matchers" + url: "/reference/matchers.html" + - title: "Actions" + url: "/reference/actions.html" + - title: "Testing FAQ" + url: "/faq.html" + - title: "Mocking FAQ" + url: "/gmock_faq.html" + - title: "Code Samples" + url: "/samples.html" + - title: "Using pkg-config" + url: "/pkgconfig.html" + - title: "Community Documentation" + url: "/community_created_documentation.html" diff --git a/3rdparty/googletest-1.13.0/docs/_layouts/default.html b/3rdparty/googletest-1.13.0/docs/_layouts/default.html new file mode 100644 index 0000000000000000000000000000000000000000..c7f331b87d7ddd4102791fd4d5b4122bfb0dd4b3 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/_layouts/default.html @@ -0,0 +1,58 @@ + + + + + + + +{% seo %} + + + + + + +
+
+ {{ content }} +
+ +
+ + + + diff --git a/3rdparty/googletest-1.13.0/docs/_sass/main.scss b/3rdparty/googletest-1.13.0/docs/_sass/main.scss new file mode 100644 index 0000000000000000000000000000000000000000..92edc877a592e877d037b769337f82568913a9d7 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/_sass/main.scss @@ -0,0 +1,200 @@ +// Styles for GoogleTest docs website on GitHub Pages. +// Color variables are defined in +// https://github.com/pages-themes/primer/tree/master/_sass/primer-support/lib/variables + +$sidebar-width: 260px; + +body { + display: flex; + margin: 0; +} + +.sidebar { + background: $black; + color: $text-white; + flex-shrink: 0; + height: 100vh; + overflow: auto; + position: sticky; + top: 0; + width: $sidebar-width; +} + +.sidebar h1 { + font-size: 1.5em; +} + +.sidebar h2 { + color: $gray-light; + font-size: 0.8em; + font-weight: normal; + margin-bottom: 0.8em; + padding-left: 2.5em; + text-transform: uppercase; +} + +.sidebar .header { + background: $black; + padding: 2em; + position: sticky; + top: 0; + width: 100%; +} + +.sidebar .header a { + color: $text-white; + text-decoration: none; +} + +.sidebar .nav-toggle { + display: none; +} + +.sidebar .expander { + cursor: pointer; + display: none; + height: 3em; + position: absolute; + right: 1em; + top: 1.5em; + width: 3em; +} + +.sidebar .expander .arrow { + border: solid $white; + border-width: 0 3px 3px 0; + display: block; + height: 0.7em; + margin: 1em auto; + transform: rotate(45deg); + transition: transform 0.5s; + width: 0.7em; +} + +.sidebar nav { + width: 100%; +} + +.sidebar nav ul { + list-style-type: none; + margin-bottom: 1em; + padding: 0; + + &:last-child { + margin-bottom: 2em; + } + + a { + text-decoration: none; + } + + li { + color: $text-white; + padding-left: 2em; + text-decoration: none; + } + + li.active { + background: $border-gray-darker; + font-weight: bold; + } + + li:hover { + background: $border-gray-darker; + } +} + +.main { + background-color: $bg-gray; + width: calc(100% - #{$sidebar-width}); +} + +.main .main-inner { + background-color: $white; + padding: 2em; +} + +.main .footer { + margin: 0; + padding: 2em; +} + +.main table th { + text-align: left; +} + +.main .callout { + border-left: 0.25em solid $white; + padding: 1em; + + a { + text-decoration: underline; + } + + &.important { + background-color: $bg-yellow-light; + border-color: $bg-yellow; + color: $black; + } + + &.note { + background-color: $bg-blue-light; + border-color: $text-blue; + color: $text-blue; + } + + &.tip { + background-color: $green-000; + border-color: $green-700; + color: $green-700; + } + + &.warning { + background-color: $red-000; + border-color: $text-red; + color: $text-red; + } +} + +.main .good pre { + background-color: $bg-green-light; +} + +.main .bad pre { + background-color: $red-000; +} + +@media all and (max-width: 768px) { + body { + flex-direction: column; + } + + .sidebar { + height: auto; + position: relative; + width: 100%; + } + + .sidebar .expander { + display: block; + } + + .sidebar nav { + height: 0; + overflow: hidden; + } + + .sidebar .nav-toggle:checked { + & ~ nav { + height: auto; + } + + & + .expander .arrow { + transform: rotate(-135deg); + } + } + + .main { + width: 100%; + } +} diff --git a/3rdparty/googletest-1.13.0/docs/advanced.md b/3rdparty/googletest-1.13.0/docs/advanced.md new file mode 100644 index 0000000000000000000000000000000000000000..f16382fe04fcd6fe91edb48e7e2748d3963a45e4 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/advanced.md @@ -0,0 +1,2407 @@ +# Advanced googletest 
Topics + +## Introduction + +Now that you have read the [googletest Primer](primer.md) and learned how to +write tests using googletest, it's time to learn some new tricks. This document +will show you more assertions as well as how to construct complex failure +messages, propagate fatal failures, reuse and speed up your test fixtures, and +use various flags with your tests. + +## More Assertions + +This section covers some less frequently used, but still significant, +assertions. + +### Explicit Success and Failure + +See [Explicit Success and Failure](reference/assertions.md#success-failure) in +the Assertions Reference. + +### Exception Assertions + +See [Exception Assertions](reference/assertions.md#exceptions) in the Assertions +Reference. + +### Predicate Assertions for Better Error Messages + +Even though googletest has a rich set of assertions, they can never be complete, +as it's impossible (nor a good idea) to anticipate all scenarios a user might +run into. Therefore, sometimes a user has to use `EXPECT_TRUE()` to check a +complex expression, for lack of a better macro. This has the problem of not +showing you the values of the parts of the expression, making it hard to +understand what went wrong. As a workaround, some users choose to construct the +failure message by themselves, streaming it into `EXPECT_TRUE()`. However, this +is awkward especially when the expression has side-effects or is expensive to +evaluate. + +googletest gives you three different options to solve this problem: + +#### Using an Existing Boolean Function + +If you already have a function or functor that returns `bool` (or a type that +can be implicitly converted to `bool`), you can use it in a *predicate +assertion* to get the function arguments printed for free. See +[`EXPECT_PRED*`](reference/assertions.md#EXPECT_PRED) in the Assertions +Reference for details. + +#### Using a Function That Returns an AssertionResult + +While `EXPECT_PRED*()` and friends are handy for a quick job, the syntax is not +satisfactory: you have to use different macros for different arities, and it +feels more like Lisp than C++. The `::testing::AssertionResult` class solves +this problem. + +An `AssertionResult` object represents the result of an assertion (whether it's +a success or a failure, and an associated message). You can create an +`AssertionResult` using one of these factory functions: + +```c++ +namespace testing { + +// Returns an AssertionResult object to indicate that an assertion has +// succeeded. +AssertionResult AssertionSuccess(); + +// Returns an AssertionResult object to indicate that an assertion has +// failed. +AssertionResult AssertionFailure(); + +} +``` + +You can then use the `<<` operator to stream messages to the `AssertionResult` +object. + +To provide more readable messages in Boolean assertions (e.g. `EXPECT_TRUE()`), +write a predicate function that returns `AssertionResult` instead of `bool`. 
For
+example, if you define `IsEven()` as:
+
+```c++
+testing::AssertionResult IsEven(int n) {
+  if ((n % 2) == 0)
+    return testing::AssertionSuccess();
+  else
+    return testing::AssertionFailure() << n << " is odd";
+}
+```
+
+instead of:
+
+```c++
+bool IsEven(int n) {
+  return (n % 2) == 0;
+}
+```
+
+the failed assertion `EXPECT_TRUE(IsEven(Fib(4)))` will print:
+
+```none
+Value of: IsEven(Fib(4))
+  Actual: false (3 is odd)
+Expected: true
+```
+
+instead of a more opaque
+
+```none
+Value of: IsEven(Fib(4))
+  Actual: false
+Expected: true
+```
+
+If you want informative messages in `EXPECT_FALSE` and `ASSERT_FALSE` as well
+(one third of Boolean assertions in the Google code base are negative ones), and
+are fine with making the predicate slower in the success case, you can supply a
+success message:
+
+```c++
+testing::AssertionResult IsEven(int n) {
+  if ((n % 2) == 0)
+    return testing::AssertionSuccess() << n << " is even";
+  else
+    return testing::AssertionFailure() << n << " is odd";
+}
+```
+
+Then the statement `EXPECT_FALSE(IsEven(Fib(6)))` will print
+
+```none
+  Value of: IsEven(Fib(6))
+    Actual: true (8 is even)
+  Expected: false
+```
+
+#### Using a Predicate-Formatter
+
+If you find the default message generated by
+[`EXPECT_PRED*`](reference/assertions.md#EXPECT_PRED) and
+[`EXPECT_TRUE`](reference/assertions.md#EXPECT_TRUE) unsatisfactory, or some
+arguments to your predicate do not support streaming to `ostream`, you can
+instead use *predicate-formatter assertions* to *fully* customize how the
+message is formatted. See
+[`EXPECT_PRED_FORMAT*`](reference/assertions.md#EXPECT_PRED_FORMAT) in the
+Assertions Reference for details.
+
+### Floating-Point Comparison
+
+See [Floating-Point Comparison](reference/assertions.md#floating-point) in the
+Assertions Reference.
+
+#### Floating-Point Predicate-Format Functions
+
+Some floating-point operations are useful, but not that often used. In order to
+avoid an explosion of new macros, we provide them as predicate-format functions
+that can be used in the predicate assertion macro
+[`EXPECT_PRED_FORMAT2`](reference/assertions.md#EXPECT_PRED_FORMAT), for
+example:
+
+```c++
+using ::testing::FloatLE;
+using ::testing::DoubleLE;
+...
+EXPECT_PRED_FORMAT2(FloatLE, val1, val2);
+EXPECT_PRED_FORMAT2(DoubleLE, val1, val2);
+```
+
+The above code verifies that `val1` is less than, or approximately equal to,
+`val2`.
+
+### Asserting Using gMock Matchers
+
+See [`EXPECT_THAT`](reference/assertions.md#EXPECT_THAT) in the Assertions
+Reference.
+
+### More String Assertions
+
+(Please read the [previous](#asserting-using-gmock-matchers) section first if
+you haven't.)
+
+You can use the gMock [string matchers](reference/matchers.md#string-matchers)
+with [`EXPECT_THAT`](reference/assertions.md#EXPECT_THAT) to do more string
+comparison tricks (sub-string, prefix, suffix, regular expression, etc.). For
+example,
+
+```c++
+using ::testing::HasSubstr;
+using ::testing::MatchesRegex;
+...
+  ASSERT_THAT(foo_string, HasSubstr("needle"));
+  EXPECT_THAT(bar_string, MatchesRegex("\\w*\\d+"));
+```
+
+### Windows HRESULT assertions
+
+See [Windows HRESULT Assertions](reference/assertions.md#HRESULT) in the
+Assertions Reference.
+
+### Type Assertions
+
+You can call the function
+
+```c++
+::testing::StaticAssertTypeEq<T1, T2>();
+```
+
+to assert that types `T1` and `T2` are the same. The function does nothing if
+the assertion is satisfied.
If the types are different, the function call will
+fail to compile, the compiler error message will say that `T1 and T2 are not the
+same type` and most likely (depending on the compiler) show you the actual
+values of `T1` and `T2`. This is mainly useful inside template code.
+
+**Caveat**: When used inside a member function of a class template or a function
+template, `StaticAssertTypeEq<T1, T2>()` is effective only if the function is
+instantiated. For example, given:
+
+```c++
+template <typename T> class Foo {
+ public:
+  void Bar() { testing::StaticAssertTypeEq<T, int>(); }
+};
+```
+
+the code:
+
+```c++
+void Test1() { Foo<bool> foo; }
+```
+
+will not generate a compiler error, as `Foo<bool>::Bar()` is never actually
+instantiated. Instead, you need:
+
+```c++
+void Test2() { Foo<bool> foo; foo.Bar(); }
+```
+
+to cause a compiler error.
+
+### Assertion Placement
+
+You can use assertions in any C++ function. In particular, it doesn't have to be
+a method of the test fixture class. The one constraint is that assertions that
+generate a fatal failure (`FAIL*` and `ASSERT_*`) can only be used in
+void-returning functions. This is a consequence of Google's not using
+exceptions. By placing it in a non-void function you'll get a confusing compile
+error like `"error: void value not ignored as it ought to be"` or `"cannot
+initialize return object of type 'bool' with an rvalue of type 'void'"` or
+`"error: no viable conversion from 'void' to 'string'"`.
+
+If you need to use fatal assertions in a function that returns non-void, one
+option is to make the function return the value in an out parameter instead. For
+example, you can rewrite `T2 Foo(T1 x)` to `void Foo(T1 x, T2* result)`. You
+need to make sure that `*result` contains some sensible value even when the
+function returns prematurely. As the function now returns `void`, you can use
+any assertion inside of it.
+
+If changing the function's type is not an option, you should just use assertions
+that generate non-fatal failures, such as `ADD_FAILURE*` and `EXPECT_*`.
+
+{: .callout .note}
+NOTE: Constructors and destructors are not considered void-returning functions,
+according to the C++ language specification, and so you may not use fatal
+assertions in them; you'll get a compilation error if you try. Instead, either
+call `abort` and crash the entire test executable, or put the fatal assertion in
+a `SetUp`/`TearDown` function; see
+[constructor/destructor vs. `SetUp`/`TearDown`](faq.md#CtorVsSetUp).
+
+{: .callout .warning}
+WARNING: A fatal assertion in a helper function (private void-returning method)
+called from a constructor or destructor does not terminate the current test, as
+your intuition might suggest: it merely returns from the constructor or
+destructor early, possibly leaving your object in a partially-constructed or
+partially-destructed state! You almost certainly want to `abort` or use
+`SetUp`/`TearDown` instead.
+
+## Skipping test execution
+
+Related to the assertions `SUCCEED()` and `FAIL()`, you can prevent further test
+execution at runtime with the `GTEST_SKIP()` macro. This is useful when you need
+to check for preconditions of the system under test during runtime and skip
+tests in a meaningful way.
+
+`GTEST_SKIP()` can be used in individual test cases or in the `SetUp()` methods
+of classes derived from either `::testing::Environment` or `::testing::Test`.
+For example: + +```c++ +TEST(SkipTest, DoesSkip) { + GTEST_SKIP() << "Skipping single test"; + EXPECT_EQ(0, 1); // Won't fail; it won't be executed +} + +class SkipFixture : public ::testing::Test { + protected: + void SetUp() override { + GTEST_SKIP() << "Skipping all tests for this fixture"; + } +}; + +// Tests for SkipFixture won't be executed. +TEST_F(SkipFixture, SkipsOneTest) { + EXPECT_EQ(5, 7); // Won't fail +} +``` + +As with assertion macros, you can stream a custom message into `GTEST_SKIP()`. + +## Teaching googletest How to Print Your Values + +When a test assertion such as `EXPECT_EQ` fails, googletest prints the argument +values to help you debug. It does this using a user-extensible value printer. + +This printer knows how to print built-in C++ types, native arrays, STL +containers, and any type that supports the `<<` operator. For other types, it +prints the raw bytes in the value and hopes that you the user can figure it out. + +As mentioned earlier, the printer is *extensible*. That means you can teach it +to do a better job at printing your particular type than to dump the bytes. To +do that, define `<<` for your type: + +```c++ +#include + +namespace foo { + +class Bar { // We want googletest to be able to print instances of this. +... + // Create a free inline friend function. + friend std::ostream& operator<<(std::ostream& os, const Bar& bar) { + return os << bar.DebugString(); // whatever needed to print bar to os + } +}; + +// If you can't declare the function in the class it's important that the +// << operator is defined in the SAME namespace that defines Bar. C++'s look-up +// rules rely on that. +std::ostream& operator<<(std::ostream& os, const Bar& bar) { + return os << bar.DebugString(); // whatever needed to print bar to os +} + +} // namespace foo +``` + +Sometimes, this might not be an option: your team may consider it bad style to +have a `<<` operator for `Bar`, or `Bar` may already have a `<<` operator that +doesn't do what you want (and you cannot change it). If so, you can instead +define a `PrintTo()` function like this: + +```c++ +#include + +namespace foo { + +class Bar { + ... + friend void PrintTo(const Bar& bar, std::ostream* os) { + *os << bar.DebugString(); // whatever needed to print bar to os + } +}; + +// If you can't declare the function in the class it's important that PrintTo() +// is defined in the SAME namespace that defines Bar. C++'s look-up rules rely +// on that. +void PrintTo(const Bar& bar, std::ostream* os) { + *os << bar.DebugString(); // whatever needed to print bar to os +} + +} // namespace foo +``` + +If you have defined both `<<` and `PrintTo()`, the latter will be used when +googletest is concerned. This allows you to customize how the value appears in +googletest's output without affecting code that relies on the behavior of its +`<<` operator. + +If you want to print a value `x` using googletest's value printer yourself, just +call `::testing::PrintToString(x)`, which returns an `std::string`: + +```c++ +vector > bar_ints = GetBarIntVector(); + +EXPECT_TRUE(IsCorrectBarIntVector(bar_ints)) + << "bar_ints = " << testing::PrintToString(bar_ints); +``` + +## Death Tests + +In many applications, there are assertions that can cause application failure if +a condition is not met. These consistency checks, which ensure that the program +is in a known good state, are there to fail at the earliest possible time after +some program state is corrupted. 
If the assertion checks the wrong condition, +then the program may proceed in an erroneous state, which could lead to memory +corruption, security holes, or worse. Hence it is vitally important to test that +such assertion statements work as expected. + +Since these precondition checks cause the processes to die, we call such tests +_death tests_. More generally, any test that checks that a program terminates +(except by throwing an exception) in an expected fashion is also a death test. + +Note that if a piece of code throws an exception, we don't consider it "death" +for the purpose of death tests, as the caller of the code could catch the +exception and avoid the crash. If you want to verify exceptions thrown by your +code, see [Exception Assertions](#ExceptionAssertions). + +If you want to test `EXPECT_*()/ASSERT_*()` failures in your test code, see +["Catching" Failures](#catching-failures). + +### How to Write a Death Test + +GoogleTest provides assertion macros to support death tests. See +[Death Assertions](reference/assertions.md#death) in the Assertions Reference +for details. + +To write a death test, simply use one of the macros inside your test function. +For example, + +```c++ +TEST(MyDeathTest, Foo) { + // This death test uses a compound statement. + ASSERT_DEATH({ + int n = 5; + Foo(&n); + }, "Error on line .* of Foo()"); +} + +TEST(MyDeathTest, NormalExit) { + EXPECT_EXIT(NormalExit(), testing::ExitedWithCode(0), "Success"); +} + +TEST(MyDeathTest, KillProcess) { + EXPECT_EXIT(KillProcess(), testing::KilledBySignal(SIGKILL), + "Sending myself unblockable signal"); +} +``` + +verifies that: + +* calling `Foo(5)` causes the process to die with the given error message, +* calling `NormalExit()` causes the process to print `"Success"` to stderr and + exit with exit code 0, and +* calling `KillProcess()` kills the process with signal `SIGKILL`. + +The test function body may contain other assertions and statements as well, if +necessary. + +Note that a death test only cares about three things: + +1. does `statement` abort or exit the process? +2. (in the case of `ASSERT_EXIT` and `EXPECT_EXIT`) does the exit status + satisfy `predicate`? Or (in the case of `ASSERT_DEATH` and `EXPECT_DEATH`) + is the exit status non-zero? And +3. does the stderr output match `matcher`? + +In particular, if `statement` generates an `ASSERT_*` or `EXPECT_*` failure, it +will **not** cause the death test to fail, as googletest assertions don't abort +the process. + +### Death Test Naming + +{: .callout .important} +IMPORTANT: We strongly recommend you to follow the convention of naming your +**test suite** (not test) `*DeathTest` when it contains a death test, as +demonstrated in the above example. The +[Death Tests And Threads](#death-tests-and-threads) section below explains why. + +If a test fixture class is shared by normal tests and death tests, you can use +`using` or `typedef` to introduce an alias for the fixture class and avoid +duplicating its code: + +```c++ +class FooTest : public testing::Test { ... }; + +using FooDeathTest = FooTest; + +TEST_F(FooTest, DoesThis) { + // normal test +} + +TEST_F(FooDeathTest, DoesThat) { + // death test +} +``` + +### Regular Expression Syntax + +When built with Bazel and using Abseil, googletest uses the +[RE2](https://github.com/google/re2/wiki/Syntax) syntax. 
Otherwise, for POSIX +systems (Linux, Cygwin, Mac), googletest uses the +[POSIX extended regular expression](http://www.opengroup.org/onlinepubs/009695399/basedefs/xbd_chap09.html#tag_09_04) +syntax. To learn about POSIX syntax, you may want to read this +[Wikipedia entry](http://en.wikipedia.org/wiki/Regular_expression#POSIX_extended). + +On Windows, googletest uses its own simple regular expression implementation. It +lacks many features. For example, we don't support union (`"x|y"`), grouping +(`"(xy)"`), brackets (`"[xy]"`), and repetition count (`"x{5,7}"`), among +others. Below is what we do support (`A` denotes a literal character, period +(`.`), or a single `\\ ` escape sequence; `x` and `y` denote regular +expressions.): + +Expression | Meaning +---------- | -------------------------------------------------------------- +`c` | matches any literal character `c` +`\\d` | matches any decimal digit +`\\D` | matches any character that's not a decimal digit +`\\f` | matches `\f` +`\\n` | matches `\n` +`\\r` | matches `\r` +`\\s` | matches any ASCII whitespace, including `\n` +`\\S` | matches any character that's not a whitespace +`\\t` | matches `\t` +`\\v` | matches `\v` +`\\w` | matches any letter, `_`, or decimal digit +`\\W` | matches any character that `\\w` doesn't match +`\\c` | matches any literal character `c`, which must be a punctuation +`.` | matches any single character except `\n` +`A?` | matches 0 or 1 occurrences of `A` +`A*` | matches 0 or many occurrences of `A` +`A+` | matches 1 or many occurrences of `A` +`^` | matches the beginning of a string (not that of each line) +`$` | matches the end of a string (not that of each line) +`xy` | matches `x` followed by `y` + +To help you determine which capability is available on your system, googletest +defines macros to govern which regular expression it is using. The macros are: +`GTEST_USES_SIMPLE_RE=1` or `GTEST_USES_POSIX_RE=1`. If you want your death +tests to work in all cases, you can either `#if` on these macros or use the more +limited syntax only. + +### How It Works + +See [Death Assertions](reference/assertions.md#death) in the Assertions +Reference. + +### Death Tests And Threads + +The reason for the two death test styles has to do with thread safety. Due to +well-known problems with forking in the presence of threads, death tests should +be run in a single-threaded context. Sometimes, however, it isn't feasible to +arrange that kind of environment. For example, statically-initialized modules +may start threads before main is ever reached. Once threads have been created, +it may be difficult or impossible to clean them up. + +googletest has three features intended to raise awareness of threading issues. + +1. A warning is emitted if multiple threads are running when a death test is + encountered. +2. Test suites with a name ending in "DeathTest" are run before all other + tests. +3. It uses `clone()` instead of `fork()` to spawn the child process on Linux + (`clone()` is not available on Cygwin and Mac), as `fork()` is more likely + to cause the child to hang when the parent process has multiple threads. + +It's perfectly fine to create threads inside a death test statement; they are +executed in a separate process and cannot affect the parent. + +### Death Test Styles + +The "threadsafe" death test style was introduced in order to help mitigate the +risks of testing in a possibly multithreaded environment. It trades increased +test execution time (potentially dramatically so) for improved thread safety. 
+ +The automated testing framework does not set the style flag. You can choose a +particular style of death tests by setting the flag programmatically: + +```c++ +GTEST_FLAG_SET(death_test_style, "threadsafe") +``` + +You can do this in `main()` to set the style for all death tests in the binary, +or in individual tests. Recall that flags are saved before running each test and +restored afterwards, so you need not do that yourself. For example: + +```c++ +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + GTEST_FLAG_SET(death_test_style, "fast"); + return RUN_ALL_TESTS(); +} + +TEST(MyDeathTest, TestOne) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + // This test is run in the "threadsafe" style: + ASSERT_DEATH(ThisShouldDie(), ""); +} + +TEST(MyDeathTest, TestTwo) { + // This test is run in the "fast" style: + ASSERT_DEATH(ThisShouldDie(), ""); +} +``` + +### Caveats + +The `statement` argument of `ASSERT_EXIT()` can be any valid C++ statement. If +it leaves the current function via a `return` statement or by throwing an +exception, the death test is considered to have failed. Some googletest macros +may return from the current function (e.g. `ASSERT_TRUE()`), so be sure to avoid +them in `statement`. + +Since `statement` runs in the child process, any in-memory side effect (e.g. +modifying a variable, releasing memory, etc) it causes will *not* be observable +in the parent process. In particular, if you release memory in a death test, +your program will fail the heap check as the parent process will never see the +memory reclaimed. To solve this problem, you can + +1. try not to free memory in a death test; +2. free the memory again in the parent process; or +3. do not use the heap checker in your program. + +Due to an implementation detail, you cannot place multiple death test assertions +on the same line; otherwise, compilation will fail with an unobvious error +message. + +Despite the improved thread safety afforded by the "threadsafe" style of death +test, thread problems such as deadlock are still possible in the presence of +handlers registered with `pthread_atfork(3)`. + +## Using Assertions in Sub-routines + +{: .callout .note} +Note: If you want to put a series of test assertions in a subroutine to check +for a complex condition, consider using +[a custom GMock matcher](gmock_cook_book.md#NewMatchers) instead. This lets you +provide a more readable error message in case of failure and avoid all of the +issues described below. + +### Adding Traces to Assertions + +If a test sub-routine is called from several places, when an assertion inside it +fails, it can be hard to tell which invocation of the sub-routine the failure is +from. You can alleviate this problem using extra logging or custom failure +messages, but that usually clutters up your tests. A better solution is to use +the `SCOPED_TRACE` macro or the `ScopedTrace` utility: + +```c++ +SCOPED_TRACE(message); +``` + +```c++ +ScopedTrace trace("file_path", line_number, message); +``` + +where `message` can be anything streamable to `std::ostream`. `SCOPED_TRACE` +macro will cause the current file name, line number, and the given message to be +added in every failure message. `ScopedTrace` accepts explicit file name and +line number in arguments, which is useful for writing test helpers. The effect +will be undone when the control leaves the current lexical scope. 
+ +For example, + +```c++ +10: void Sub1(int n) { +11: EXPECT_EQ(Bar(n), 1); +12: EXPECT_EQ(Bar(n + 1), 2); +13: } +14: +15: TEST(FooTest, Bar) { +16: { +17: SCOPED_TRACE("A"); // This trace point will be included in +18: // every failure in this scope. +19: Sub1(1); +20: } +21: // Now it won't. +22: Sub1(9); +23: } +``` + +could result in messages like these: + +```none +path/to/foo_test.cc:11: Failure +Value of: Bar(n) +Expected: 1 + Actual: 2 +Google Test trace: +path/to/foo_test.cc:17: A + +path/to/foo_test.cc:12: Failure +Value of: Bar(n + 1) +Expected: 2 + Actual: 3 +``` + +Without the trace, it would've been difficult to know which invocation of +`Sub1()` the two failures come from respectively. (You could add an extra +message to each assertion in `Sub1()` to indicate the value of `n`, but that's +tedious.) + +Some tips on using `SCOPED_TRACE`: + +1. With a suitable message, it's often enough to use `SCOPED_TRACE` at the + beginning of a sub-routine, instead of at each call site. +2. When calling sub-routines inside a loop, make the loop iterator part of the + message in `SCOPED_TRACE` such that you can know which iteration the failure + is from. +3. Sometimes the line number of the trace point is enough for identifying the + particular invocation of a sub-routine. In this case, you don't have to + choose a unique message for `SCOPED_TRACE`. You can simply use `""`. +4. You can use `SCOPED_TRACE` in an inner scope when there is one in the outer + scope. In this case, all active trace points will be included in the failure + messages, in reverse order they are encountered. +5. The trace dump is clickable in Emacs - hit `return` on a line number and + you'll be taken to that line in the source file! + +### Propagating Fatal Failures + +A common pitfall when using `ASSERT_*` and `FAIL*` is not understanding that +when they fail they only abort the _current function_, not the entire test. For +example, the following test will segfault: + +```c++ +void Subroutine() { + // Generates a fatal failure and aborts the current function. + ASSERT_EQ(1, 2); + + // The following won't be executed. + ... +} + +TEST(FooTest, Bar) { + Subroutine(); // The intended behavior is for the fatal failure + // in Subroutine() to abort the entire test. + + // The actual behavior: the function goes on after Subroutine() returns. + int* p = nullptr; + *p = 3; // Segfault! +} +``` + +To alleviate this, googletest provides three different solutions. You could use +either exceptions, the `(ASSERT|EXPECT)_NO_FATAL_FAILURE` assertions or the +`HasFatalFailure()` function. They are described in the following two +subsections. + +#### Asserting on Subroutines with an exception + +The following code can turn ASSERT-failure into an exception: + +```c++ +class ThrowListener : public testing::EmptyTestEventListener { + void OnTestPartResult(const testing::TestPartResult& result) override { + if (result.type() == testing::TestPartResult::kFatalFailure) { + throw testing::AssertionException(result); + } + } +}; +int main(int argc, char** argv) { + ... + testing::UnitTest::GetInstance()->listeners().Append(new ThrowListener); + return RUN_ALL_TESTS(); +} +``` + +This listener should be added after other listeners if you have any, otherwise +they won't see failed `OnTestPartResult`. + +#### Asserting on Subroutines + +As shown above, if your test calls a subroutine that has an `ASSERT_*` failure +in it, the test will continue after the subroutine returns. This may not be what +you want. 
+ +Often people want fatal failures to propagate like exceptions. For that +googletest offers the following macros: + +Fatal assertion | Nonfatal assertion | Verifies +------------------------------------- | ------------------------------------- | -------- +`ASSERT_NO_FATAL_FAILURE(statement);` | `EXPECT_NO_FATAL_FAILURE(statement);` | `statement` doesn't generate any new fatal failures in the current thread. + +Only failures in the thread that executes the assertion are checked to determine +the result of this type of assertions. If `statement` creates new threads, +failures in these threads are ignored. + +Examples: + +```c++ +ASSERT_NO_FATAL_FAILURE(Foo()); + +int i; +EXPECT_NO_FATAL_FAILURE({ + i = Bar(); +}); +``` + +Assertions from multiple threads are currently not supported on Windows. + +#### Checking for Failures in the Current Test + +`HasFatalFailure()` in the `::testing::Test` class returns `true` if an +assertion in the current test has suffered a fatal failure. This allows +functions to catch fatal failures in a sub-routine and return early. + +```c++ +class Test { + public: + ... + static bool HasFatalFailure(); +}; +``` + +The typical usage, which basically simulates the behavior of a thrown exception, +is: + +```c++ +TEST(FooTest, Bar) { + Subroutine(); + // Aborts if Subroutine() had a fatal failure. + if (HasFatalFailure()) return; + + // The following won't be executed. + ... +} +``` + +If `HasFatalFailure()` is used outside of `TEST()` , `TEST_F()` , or a test +fixture, you must add the `::testing::Test::` prefix, as in: + +```c++ +if (testing::Test::HasFatalFailure()) return; +``` + +Similarly, `HasNonfatalFailure()` returns `true` if the current test has at +least one non-fatal failure, and `HasFailure()` returns `true` if the current +test has at least one failure of either kind. + +## Logging Additional Information + +In your test code, you can call `RecordProperty("key", value)` to log additional +information, where `value` can be either a string or an `int`. The *last* value +recorded for a key will be emitted to the +[XML output](#generating-an-xml-report) if you specify one. For example, the +test + +```c++ +TEST_F(WidgetUsageTest, MinAndMaxWidgets) { + RecordProperty("MaximumWidgets", ComputeMaxUsage()); + RecordProperty("MinimumWidgets", ComputeMinUsage()); +} +``` + +will output XML like this: + +```xml + ... + + ... +``` + +{: .callout .note} +> NOTE: +> +> * `RecordProperty()` is a static member of the `Test` class. Therefore it +> needs to be prefixed with `::testing::Test::` if used outside of the +> `TEST` body and the test fixture class. +> * *`key`* must be a valid XML attribute name, and cannot conflict with the +> ones already used by googletest (`name`, `status`, `time`, `classname`, +> `type_param`, and `value_param`). +> * Calling `RecordProperty()` outside of the lifespan of a test is allowed. +> If it's called outside of a test but between a test suite's +> `SetUpTestSuite()` and `TearDownTestSuite()` methods, it will be +> attributed to the XML element for the test suite. If it's called outside +> of all test suites (e.g. in a test environment), it will be attributed to +> the top-level XML element. + +## Sharing Resources Between Tests in the Same Test Suite + +googletest creates a new test fixture object for each test in order to make +tests independent and easier to debug. However, sometimes tests use resources +that are expensive to set up, making the one-copy-per-test model prohibitively +expensive. 
+ +If the tests don't change the resource, there's no harm in their sharing a +single resource copy. So, in addition to per-test set-up/tear-down, googletest +also supports per-test-suite set-up/tear-down. To use it: + +1. In your test fixture class (say `FooTest` ), declare as `static` some member + variables to hold the shared resources. +2. Outside your test fixture class (typically just below it), define those + member variables, optionally giving them initial values. +3. In the same test fixture class, define a `static void SetUpTestSuite()` + function (remember not to spell it as **`SetupTestSuite`** with a small + `u`!) to set up the shared resources and a `static void TearDownTestSuite()` + function to tear them down. + +That's it! googletest automatically calls `SetUpTestSuite()` before running the +*first test* in the `FooTest` test suite (i.e. before creating the first +`FooTest` object), and calls `TearDownTestSuite()` after running the *last test* +in it (i.e. after deleting the last `FooTest` object). In between, the tests can +use the shared resources. + +Remember that the test order is undefined, so your code can't depend on a test +preceding or following another. Also, the tests must either not modify the state +of any shared resource, or, if they do modify the state, they must restore the +state to its original value before passing control to the next test. + +Note that `SetUpTestSuite()` may be called multiple times for a test fixture +class that has derived classes, so you should not expect code in the function +body to be run only once. Also, derived classes still have access to shared +resources defined as static members, so careful consideration is needed when +managing shared resources to avoid memory leaks. + +Here's an example of per-test-suite set-up and tear-down: + +```c++ +class FooTest : public testing::Test { + protected: + // Per-test-suite set-up. + // Called before the first test in this test suite. + // Can be omitted if not needed. + static void SetUpTestSuite() { + // Avoid reallocating static objects if called in subclasses of FooTest. + if (shared_resource_ == nullptr) { + shared_resource_ = new ...; + } + } + + // Per-test-suite tear-down. + // Called after the last test in this test suite. + // Can be omitted if not needed. + static void TearDownTestSuite() { + delete shared_resource_; + shared_resource_ = nullptr; + } + + // You can define per-test set-up logic as usual. + void SetUp() override { ... } + + // You can define per-test tear-down logic as usual. + void TearDown() override { ... } + + // Some expensive resource shared by all tests. + static T* shared_resource_; +}; + +T* FooTest::shared_resource_ = nullptr; + +TEST_F(FooTest, Test1) { + ... you can refer to shared_resource_ here ... +} + +TEST_F(FooTest, Test2) { + ... you can refer to shared_resource_ here ... +} +``` + +{: .callout .note} +NOTE: Though the above code declares `SetUpTestSuite()` protected, it may +sometimes be necessary to declare it public, such as when using it with +`TEST_P`. + +## Global Set-Up and Tear-Down + +Just as you can do set-up and tear-down at the test level and the test suite +level, you can also do it at the test program level. Here's how. + +First, you subclass the `::testing::Environment` class to define a test +environment, which knows how to set-up and tear-down: + +```c++ +class Environment : public ::testing::Environment { + public: + ~Environment() override {} + + // Override this to define how to set up the environment. 
+ void SetUp() override {} + + // Override this to define how to tear down the environment. + void TearDown() override {} +}; +``` + +Then, you register an instance of your environment class with googletest by +calling the `::testing::AddGlobalTestEnvironment()` function: + +```c++ +Environment* AddGlobalTestEnvironment(Environment* env); +``` + +Now, when `RUN_ALL_TESTS()` is called, it first calls the `SetUp()` method of +each environment object, then runs the tests if none of the environments +reported fatal failures and `GTEST_SKIP()` was not called. `RUN_ALL_TESTS()` +always calls `TearDown()` with each environment object, regardless of whether or +not the tests were run. + +It's OK to register multiple environment objects. In this suite, their `SetUp()` +will be called in the order they are registered, and their `TearDown()` will be +called in the reverse order. + +Note that googletest takes ownership of the registered environment objects. +Therefore **do not delete them** by yourself. + +You should call `AddGlobalTestEnvironment()` before `RUN_ALL_TESTS()` is called, +probably in `main()`. If you use `gtest_main`, you need to call this before +`main()` starts for it to take effect. One way to do this is to define a global +variable like this: + +```c++ +testing::Environment* const foo_env = + testing::AddGlobalTestEnvironment(new FooEnvironment); +``` + +However, we strongly recommend you to write your own `main()` and call +`AddGlobalTestEnvironment()` there, as relying on initialization of global +variables makes the code harder to read and may cause problems when you register +multiple environments from different translation units and the environments have +dependencies among them (remember that the compiler doesn't guarantee the order +in which global variables from different translation units are initialized). + +## Value-Parameterized Tests + +*Value-parameterized tests* allow you to test your code with different +parameters without writing multiple copies of the same test. This is useful in a +number of situations, for example: + +* You have a piece of code whose behavior is affected by one or more + command-line flags. You want to make sure your code performs correctly for + various values of those flags. +* You want to test different implementations of an OO interface. +* You want to test your code over various inputs (a.k.a. data-driven testing). + This feature is easy to abuse, so please exercise your good sense when doing + it! + +### How to Write Value-Parameterized Tests + +To write value-parameterized tests, first you should define a fixture class. It +must be derived from both `testing::Test` and `testing::WithParamInterface` +(the latter is a pure interface), where `T` is the type of your parameter +values. For convenience, you can just derive the fixture class from +`testing::TestWithParam`, which itself is derived from both `testing::Test` +and `testing::WithParamInterface`. `T` can be any copyable type. If it's a +raw pointer, you are responsible for managing the lifespan of the pointed +values. + +{: .callout .note} +NOTE: If your test fixture defines `SetUpTestSuite()` or `TearDownTestSuite()` +they must be declared **public** rather than **protected** in order to use +`TEST_P`. + +```c++ +class FooTest : + public testing::TestWithParam { + // You can implement all the usual fixture class members here. + // To access the test parameter, call GetParam() from class + // TestWithParam. 
+}; + +// Or, when you want to add parameters to a pre-existing fixture class: +class BaseTest : public testing::Test { + ... +}; +class BarTest : public BaseTest, + public testing::WithParamInterface { + ... +}; +``` + +Then, use the `TEST_P` macro to define as many test patterns using this fixture +as you want. The `_P` suffix is for "parameterized" or "pattern", whichever you +prefer to think. + +```c++ +TEST_P(FooTest, DoesBlah) { + // Inside a test, access the test parameter with the GetParam() method + // of the TestWithParam class: + EXPECT_TRUE(foo.Blah(GetParam())); + ... +} + +TEST_P(FooTest, HasBlahBlah) { + ... +} +``` + +Finally, you can use the `INSTANTIATE_TEST_SUITE_P` macro to instantiate the +test suite with any set of parameters you want. GoogleTest defines a number of +functions for generating test parameters—see details at +[`INSTANTIATE_TEST_SUITE_P`](reference/testing.md#INSTANTIATE_TEST_SUITE_P) in +the Testing Reference. + +For example, the following statement will instantiate tests from the `FooTest` +test suite each with parameter values `"meeny"`, `"miny"`, and `"moe"` using the +[`Values`](reference/testing.md#param-generators) parameter generator: + +```c++ +INSTANTIATE_TEST_SUITE_P(MeenyMinyMoe, + FooTest, + testing::Values("meeny", "miny", "moe")); +``` + +{: .callout .note} +NOTE: The code above must be placed at global or namespace scope, not at +function scope. + +The first argument to `INSTANTIATE_TEST_SUITE_P` is a unique name for the +instantiation of the test suite. The next argument is the name of the test +pattern, and the last is the +[parameter generator](reference/testing.md#param-generators). + +The parameter generator expression is not evaluated until GoogleTest is +initialized (via `InitGoogleTest()`). Any prior initialization done in the +`main` function will be accessible from the parameter generator, for example, +the results of flag parsing. + +You can instantiate a test pattern more than once, so to distinguish different +instances of the pattern, the instantiation name is added as a prefix to the +actual test suite name. Remember to pick unique prefixes for different +instantiations. The tests from the instantiation above will have these names: + +* `MeenyMinyMoe/FooTest.DoesBlah/0` for `"meeny"` +* `MeenyMinyMoe/FooTest.DoesBlah/1` for `"miny"` +* `MeenyMinyMoe/FooTest.DoesBlah/2` for `"moe"` +* `MeenyMinyMoe/FooTest.HasBlahBlah/0` for `"meeny"` +* `MeenyMinyMoe/FooTest.HasBlahBlah/1` for `"miny"` +* `MeenyMinyMoe/FooTest.HasBlahBlah/2` for `"moe"` + +You can use these names in [`--gtest_filter`](#running-a-subset-of-the-tests). + +The following statement will instantiate all tests from `FooTest` again, each +with parameter values `"cat"` and `"dog"` using the +[`ValuesIn`](reference/testing.md#param-generators) parameter generator: + +```c++ +const char* pets[] = {"cat", "dog"}; +INSTANTIATE_TEST_SUITE_P(Pets, FooTest, testing::ValuesIn(pets)); +``` + +The tests from the instantiation above will have these names: + +* `Pets/FooTest.DoesBlah/0` for `"cat"` +* `Pets/FooTest.DoesBlah/1` for `"dog"` +* `Pets/FooTest.HasBlahBlah/0` for `"cat"` +* `Pets/FooTest.HasBlahBlah/1` for `"dog"` + +Please note that `INSTANTIATE_TEST_SUITE_P` will instantiate *all* tests in the +given test suite, whether their definitions come before or *after* the +`INSTANTIATE_TEST_SUITE_P` statement. + +Additionally, by default, every `TEST_P` without a corresponding +`INSTANTIATE_TEST_SUITE_P` causes a failing test in test suite +`GoogleTestVerification`. 
If you have a test suite where that omission is not an +error, for example it is in a library that may be linked in for other reasons or +where the list of test cases is dynamic and may be empty, then this check can be +suppressed by tagging the test suite: + +```c++ +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(FooTest); +``` + +You can see [sample7_unittest.cc] and [sample8_unittest.cc] for more examples. + +[sample7_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample7_unittest.cc "Parameterized Test example" +[sample8_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample8_unittest.cc "Parameterized Test example with multiple parameters" + +### Creating Value-Parameterized Abstract Tests + +In the above, we define and instantiate `FooTest` in the *same* source file. +Sometimes you may want to define value-parameterized tests in a library and let +other people instantiate them later. This pattern is known as *abstract tests*. +As an example of its application, when you are designing an interface you can +write a standard suite of abstract tests (perhaps using a factory function as +the test parameter) that all implementations of the interface are expected to +pass. When someone implements the interface, they can instantiate your suite to +get all the interface-conformance tests for free. + +To define abstract tests, you should organize your code like this: + +1. Put the definition of the parameterized test fixture class (e.g. `FooTest`) + in a header file, say `foo_param_test.h`. Think of this as *declaring* your + abstract tests. +2. Put the `TEST_P` definitions in `foo_param_test.cc`, which includes + `foo_param_test.h`. Think of this as *implementing* your abstract tests. + +Once they are defined, you can instantiate them by including `foo_param_test.h`, +invoking `INSTANTIATE_TEST_SUITE_P()`, and depending on the library target that +contains `foo_param_test.cc`. You can instantiate the same abstract test suite +multiple times, possibly in different source files. + +### Specifying Names for Value-Parameterized Test Parameters + +The optional last argument to `INSTANTIATE_TEST_SUITE_P()` allows the user to +specify a function or functor that generates custom test name suffixes based on +the test parameters. The function should accept one argument of type +`testing::TestParamInfo`, and return `std::string`. + +`testing::PrintToStringParamName` is a builtin test suffix generator that +returns the value of `testing::PrintToString(GetParam())`. It does not work for +`std::string` or C strings. + +{: .callout .note} +NOTE: test names must be non-empty, unique, and may only contain ASCII +alphanumeric characters. In particular, they +[should not contain underscores](faq.md#why-should-test-suite-names-and-test-names-not-contain-underscore) + +```c++ +class MyTestSuite : public testing::TestWithParam {}; + +TEST_P(MyTestSuite, MyTest) +{ + std::cout << "Example Test Param: " << GetParam() << std::endl; +} + +INSTANTIATE_TEST_SUITE_P(MyGroup, MyTestSuite, testing::Range(0, 10), + testing::PrintToStringParamName()); +``` + +Providing a custom functor allows for more control over test parameter name +generation, especially for types where the automatic conversion does not +generate helpful parameter names (e.g. strings as demonstrated above). The +following example illustrates this for multiple parameters, an enumeration type +and a string, and also demonstrates how to combine generators. 
It uses a lambda +for conciseness: + +```c++ +enum class MyType { MY_FOO = 0, MY_BAR = 1 }; + +class MyTestSuite : public testing::TestWithParam<std::tuple<MyType, std::string>> { +}; + +INSTANTIATE_TEST_SUITE_P( + MyGroup, MyTestSuite, + testing::Combine( + testing::Values(MyType::MY_FOO, MyType::MY_BAR), + testing::Values("A", "B")), + [](const testing::TestParamInfo<MyTestSuite::ParamType>& info) { + std::string name = absl::StrCat( + std::get<0>(info.param) == MyType::MY_FOO ? "Foo" : "Bar", + std::get<1>(info.param)); + absl::c_replace_if(name, [](char c) { return !std::isalnum(c); }, '_'); + return name; + }); +``` + +## Typed Tests + +Suppose you have multiple implementations of the same interface and want to make +sure that all of them satisfy some common requirements. Or, you may have defined +several types that are supposed to conform to the same "concept" and you want to +verify it. In both cases, you want the same test logic repeated for different +types. + +While you can write one `TEST` or `TEST_F` for each type you want to test (and +you may even factor the test logic into a function template that you invoke from +the `TEST`), it's tedious and doesn't scale: if you want `m` tests over `n` +types, you'll end up writing `m*n` `TEST`s. + +*Typed tests* allow you to repeat the same test logic over a list of types. You +only need to write the test logic once, although you must know the type list +when writing typed tests. Here's how you do it: + +First, define a fixture class template. It should be parameterized by a type. +Remember to derive it from `::testing::Test`: + +```c++ +template <typename T> +class FooTest : public testing::Test { + public: + ... + using List = std::list<T>; + static T shared_; + T value_; +}; +``` + +Next, associate a list of types with the test suite, which will be repeated for +each type in the list: + +```c++ +using MyTypes = ::testing::Types<char, int, unsigned int>; +TYPED_TEST_SUITE(FooTest, MyTypes); +``` + +The type alias (`using` or `typedef`) is necessary for the `TYPED_TEST_SUITE` +macro to parse correctly. Otherwise the compiler will think that each comma in +the type list introduces a new macro argument. + +Then, use `TYPED_TEST()` instead of `TEST_F()` to define a typed test for this +test suite. You can repeat this as many times as you want: + +```c++ +TYPED_TEST(FooTest, DoesBlah) { + // Inside a test, refer to the special name TypeParam to get the type + // parameter. Since we are inside a derived class template, C++ requires + // us to visit the members of FooTest via 'this'. + TypeParam n = this->value_; + + // To visit static members of the fixture, add the 'TestFixture::' + // prefix. + n += TestFixture::shared_; + + // To refer to typedefs in the fixture, add the 'typename TestFixture::' + // prefix. The 'typename' is required to satisfy the compiler. + typename TestFixture::List values; + + values.push_back(n); + ... +} + +TYPED_TEST(FooTest, HasPropertyA) { ... } +``` + +You can see [sample6_unittest.cc] for a complete example. + +[sample6_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample6_unittest.cc "Typed Test example" + +## Type-Parameterized Tests + +*Type-parameterized tests* are like typed tests, except that they don't require +you to know the list of types ahead of time. Instead, you can define the test +logic first and instantiate it with different type lists later. You can even +instantiate it more than once in the same program.
+ +If you are designing an interface or concept, you can define a suite of +type-parameterized tests to verify properties that any valid implementation of +the interface/concept should have. Then, the author of each implementation can +just instantiate the test suite with their type to verify that it conforms to +the requirements, without having to write similar tests repeatedly. Here's an +example: + +First, define a fixture class template, as we did with typed tests: + +```c++ +template +class FooTest : public testing::Test { + void DoSomethingInteresting(); + ... +}; +``` + +Next, declare that you will define a type-parameterized test suite: + +```c++ +TYPED_TEST_SUITE_P(FooTest); +``` + +Then, use `TYPED_TEST_P()` to define a type-parameterized test. You can repeat +this as many times as you want: + +```c++ +TYPED_TEST_P(FooTest, DoesBlah) { + // Inside a test, refer to TypeParam to get the type parameter. + TypeParam n = 0; + + // You will need to use `this` explicitly to refer to fixture members. + this->DoSomethingInteresting() + ... +} + +TYPED_TEST_P(FooTest, HasPropertyA) { ... } +``` + +Now the tricky part: you need to register all test patterns using the +`REGISTER_TYPED_TEST_SUITE_P` macro before you can instantiate them. The first +argument of the macro is the test suite name; the rest are the names of the +tests in this test suite: + +```c++ +REGISTER_TYPED_TEST_SUITE_P(FooTest, + DoesBlah, HasPropertyA); +``` + +Finally, you are free to instantiate the pattern with the types you want. If you +put the above code in a header file, you can `#include` it in multiple C++ +source files and instantiate it multiple times. + +```c++ +using MyTypes = ::testing::Types; +INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, MyTypes); +``` + +To distinguish different instances of the pattern, the first argument to the +`INSTANTIATE_TYPED_TEST_SUITE_P` macro is a prefix that will be added to the +actual test suite name. Remember to pick unique prefixes for different +instances. + +In the special case where the type list contains only one type, you can write +that type directly without `::testing::Types<...>`, like this: + +```c++ +INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, int); +``` + +You can see [sample6_unittest.cc] for a complete example. + +## Testing Private Code + +If you change your software's internal implementation, your tests should not +break as long as the change is not observable by users. Therefore, **per the +black-box testing principle, most of the time you should test your code through +its public interfaces.** + +**If you still find yourself needing to test internal implementation code, +consider if there's a better design.** The desire to test internal +implementation is often a sign that the class is doing too much. Consider +extracting an implementation class, and testing it. Then use that implementation +class in the original class. + +If you absolutely have to test non-public interface code though, you can. There +are two cases to consider: + +* Static functions ( *not* the same as static member functions!) or unnamed + namespaces, and +* Private or protected class members + +To test them, we use the following special techniques: + +* Both static functions and definitions/declarations in an unnamed namespace + are only visible within the same translation unit. To test them, you can + `#include` the entire `.cc` file being tested in your `*_test.cc` file. + (#including `.cc` files is not a good way to reuse code - you should not do + this in production code!) 
+ + However, a better approach is to move the private code into the + `foo::internal` namespace, where `foo` is the namespace your project + normally uses, and put the private declarations in a `*-internal.h` file. + Your production `.cc` files and your tests are allowed to include this + internal header, but your clients are not. This way, you can fully test your + internal implementation without leaking it to your clients. + +* Private class members are only accessible from within the class or by + friends. To access a class' private members, you can declare your test + fixture as a friend to the class and define accessors in your fixture. Tests + using the fixture can then access the private members of your production + class via the accessors in the fixture. Note that even though your fixture + is a friend to your production class, your tests are not automatically + friends to it, as they are technically defined in sub-classes of the + fixture. + + Another way to test private members is to refactor them into an + implementation class, which is then declared in a `*-internal.h` file. Your + clients aren't allowed to include this header but your tests can. Such is + called the + [Pimpl](https://www.gamedev.net/articles/programming/general-and-gameplay-programming/the-c-pimpl-r1794/) + (Private Implementation) idiom. + + Or, you can declare an individual test as a friend of your class by adding + this line in the class body: + + ```c++ + FRIEND_TEST(TestSuiteName, TestName); + ``` + + For example, + + ```c++ + // foo.h + class Foo { + ... + private: + FRIEND_TEST(FooTest, BarReturnsZeroOnNull); + + int Bar(void* x); + }; + + // foo_test.cc + ... + TEST(FooTest, BarReturnsZeroOnNull) { + Foo foo; + EXPECT_EQ(foo.Bar(NULL), 0); // Uses Foo's private member Bar(). + } + ``` + + Pay special attention when your class is defined in a namespace. If you want + your test fixtures and tests to be friends of your class, then they must be + defined in the exact same namespace (no anonymous or inline namespaces). + + For example, if the code to be tested looks like: + + ```c++ + namespace my_namespace { + + class Foo { + friend class FooTest; + FRIEND_TEST(FooTest, Bar); + FRIEND_TEST(FooTest, Baz); + ... definition of the class Foo ... + }; + + } // namespace my_namespace + ``` + + Your test code should be something like: + + ```c++ + namespace my_namespace { + + class FooTest : public testing::Test { + protected: + ... + }; + + TEST_F(FooTest, Bar) { ... } + TEST_F(FooTest, Baz) { ... } + + } // namespace my_namespace + ``` + +## "Catching" Failures + +If you are building a testing utility on top of googletest, you'll want to test +your utility. What framework would you use to test it? googletest, of course. + +The challenge is to verify that your testing utility reports failures correctly. +In frameworks that report a failure by throwing an exception, you could catch +the exception and assert on it. But googletest doesn't use exceptions, so how do +we test that a piece of code generates an expected failure? + +`"gtest/gtest-spi.h"` contains some constructs to do this. +After #including this header, you can use + +```c++ + EXPECT_FATAL_FAILURE(statement, substring); +``` + +to assert that `statement` generates a fatal (e.g. `ASSERT_*`) failure in the +current thread whose message contains the given `substring`, or use + +```c++ + EXPECT_NONFATAL_FAILURE(statement, substring); +``` + +if you are expecting a non-fatal (e.g. `EXPECT_*`) failure. 
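+
+For instance, a small utility assertion that wraps `EXPECT_*` could be checked
+roughly like this (a minimal sketch; `ExpectPositive` and the surrounding test
+are hypothetical and not part of googletest):
+
+```c++
+#include "gtest/gtest-spi.h"
+#include "gtest/gtest.h"
+
+// Hypothetical helper under test: reports a non-fatal failure for bad input.
+void ExpectPositive(int n) { EXPECT_GT(n, 0) << "value must be positive"; }
+
+TEST(MyUtilityTest, ReportsFailureForNonPositiveValues) {
+  // Passes only if the statement produces exactly one non-fatal failure
+  // whose message contains the given substring.
+  EXPECT_NONFATAL_FAILURE(ExpectPositive(-1), "value must be positive");
+}
+```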
+ +Only failures in the current thread are checked to determine the result of this +type of expectations. If `statement` creates new threads, failures in these +threads are also ignored. If you want to catch failures in other threads as +well, use one of the following macros instead: + +```c++ + EXPECT_FATAL_FAILURE_ON_ALL_THREADS(statement, substring); + EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS(statement, substring); +``` + +{: .callout .note} +NOTE: Assertions from multiple threads are currently not supported on Windows. + +For technical reasons, there are some caveats: + +1. You cannot stream a failure message to either macro. + +2. `statement` in `EXPECT_FATAL_FAILURE{_ON_ALL_THREADS}()` cannot reference + local non-static variables or non-static members of `this` object. + +3. `statement` in `EXPECT_FATAL_FAILURE{_ON_ALL_THREADS}()` cannot return a + value. + +## Registering tests programmatically + +The `TEST` macros handle the vast majority of all use cases, but there are few +where runtime registration logic is required. For those cases, the framework +provides the `::testing::RegisterTest` that allows callers to register arbitrary +tests dynamically. + +This is an advanced API only to be used when the `TEST` macros are insufficient. +The macros should be preferred when possible, as they avoid most of the +complexity of calling this function. + +It provides the following signature: + +```c++ +template +TestInfo* RegisterTest(const char* test_suite_name, const char* test_name, + const char* type_param, const char* value_param, + const char* file, int line, Factory factory); +``` + +The `factory` argument is a factory callable (move-constructible) object or +function pointer that creates a new instance of the Test object. It handles +ownership to the caller. The signature of the callable is `Fixture*()`, where +`Fixture` is the test fixture class for the test. All tests registered with the +same `test_suite_name` must return the same fixture type. This is checked at +runtime. + +The framework will infer the fixture class from the factory and will call the +`SetUpTestSuite` and `TearDownTestSuite` for it. + +Must be called before `RUN_ALL_TESTS()` is invoked, otherwise behavior is +undefined. + +Use case example: + +```c++ +class MyFixture : public testing::Test { + public: + // All of these optional, just like in regular macro usage. + static void SetUpTestSuite() { ... } + static void TearDownTestSuite() { ... } + void SetUp() override { ... } + void TearDown() override { ... } +}; + +class MyTest : public MyFixture { + public: + explicit MyTest(int data) : data_(data) {} + void TestBody() override { ... } + + private: + int data_; +}; + +void RegisterMyTests(const std::vector& values) { + for (int v : values) { + testing::RegisterTest( + "MyFixture", ("Test" + std::to_string(v)).c_str(), nullptr, + std::to_string(v).c_str(), + __FILE__, __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MyFixture* { return new MyTest(v); }); + } +} +... +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + std::vector values_to_test = LoadValuesFromConfig(); + RegisterMyTests(values_to_test); + ... + return RUN_ALL_TESTS(); +} +``` + +## Getting the Current Test's Name + +Sometimes a function may need to know the name of the currently running test. +For example, you may be using the `SetUp()` method of your test fixture to set +the golden file name based on which test is running. 
The +[`TestInfo`](reference/testing.md#TestInfo) class has this information. + +To obtain a `TestInfo` object for the currently running test, call +`current_test_info()` on the [`UnitTest`](reference/testing.md#UnitTest) +singleton object: + +```c++ + // Gets information about the currently running test. + // Do NOT delete the returned object - it's managed by the UnitTest class. + const testing::TestInfo* const test_info = + testing::UnitTest::GetInstance()->current_test_info(); + + printf("We are in test %s of test suite %s.\n", + test_info->name(), + test_info->test_suite_name()); +``` + +`current_test_info()` returns a null pointer if no test is running. In +particular, you cannot find the test suite name in `SetUpTestSuite()`, +`TearDownTestSuite()` (where you know the test suite name implicitly), or +functions called from them. + +## Extending googletest by Handling Test Events + +googletest provides an **event listener API** to let you receive notifications +about the progress of a test program and test failures. The events you can +listen to include the start and end of the test program, a test suite, or a test +method, among others. You may use this API to augment or replace the standard +console output, replace the XML output, or provide a completely different form +of output, such as a GUI or a database. You can also use test events as +checkpoints to implement a resource leak checker, for example. + +### Defining Event Listeners + +To define a event listener, you subclass either +[`testing::TestEventListener`](reference/testing.md#TestEventListener) or +[`testing::EmptyTestEventListener`](reference/testing.md#EmptyTestEventListener) +The former is an (abstract) interface, where *each pure virtual method can be +overridden to handle a test event* (For example, when a test starts, the +`OnTestStart()` method will be called.). The latter provides an empty +implementation of all methods in the interface, such that a subclass only needs +to override the methods it cares about. + +When an event is fired, its context is passed to the handler function as an +argument. The following argument types are used: + +* UnitTest reflects the state of the entire test program, +* TestSuite has information about a test suite, which can contain one or more + tests, +* TestInfo contains the state of a test, and +* TestPartResult represents the result of a test assertion. + +An event handler function can examine the argument it receives to find out +interesting information about the event and the test program's state. + +Here's an example: + +```c++ + class MinimalistPrinter : public testing::EmptyTestEventListener { + // Called before a test starts. + void OnTestStart(const testing::TestInfo& test_info) override { + printf("*** Test %s.%s starting.\n", + test_info.test_suite_name(), test_info.name()); + } + + // Called after a failed assertion or a SUCCESS(). + void OnTestPartResult(const testing::TestPartResult& test_part_result) override { + printf("%s in %s:%d\n%s\n", + test_part_result.failed() ? "*** Failure" : "Success", + test_part_result.file_name(), + test_part_result.line_number(), + test_part_result.summary()); + } + + // Called after a test ends. 
+ void OnTestEnd(const testing::TestInfo& test_info) override { + printf("*** Test %s.%s ending.\n", + test_info.test_suite_name(), test_info.name()); + } + }; +``` + +### Using Event Listeners + +To use the event listener you have defined, add an instance of it to the +googletest event listener list (represented by class +[`TestEventListeners`](reference/testing.md#TestEventListeners) - note the "s" +at the end of the name) in your `main()` function, before calling +`RUN_ALL_TESTS()`: + +```c++ +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + // Gets hold of the event listener list. + testing::TestEventListeners& listeners = + testing::UnitTest::GetInstance()->listeners(); + // Adds a listener to the end. googletest takes the ownership. + listeners.Append(new MinimalistPrinter); + return RUN_ALL_TESTS(); +} +``` + +There's only one problem: the default test result printer is still in effect, so +its output will mingle with the output from your minimalist printer. To suppress +the default printer, just release it from the event listener list and delete it. +You can do so by adding one line: + +```c++ + ... + delete listeners.Release(listeners.default_result_printer()); + listeners.Append(new MinimalistPrinter); + return RUN_ALL_TESTS(); +``` + +Now, sit back and enjoy a completely different output from your tests. For more +details, see [sample9_unittest.cc]. + +[sample9_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample9_unittest.cc "Event listener example" + +You may append more than one listener to the list. When an `On*Start()` or +`OnTestPartResult()` event is fired, the listeners will receive it in the order +they appear in the list (since new listeners are added to the end of the list, +the default text printer and the default XML generator will receive the event +first). An `On*End()` event will be received by the listeners in the *reverse* +order. This allows output by listeners added later to be framed by output from +listeners added earlier. + +### Generating Failures in Listeners + +You may use failure-raising macros (`EXPECT_*()`, `ASSERT_*()`, `FAIL()`, etc) +when processing an event. There are some restrictions: + +1. You cannot generate any failure in `OnTestPartResult()` (otherwise it will + cause `OnTestPartResult()` to be called recursively). +2. A listener that handles `OnTestPartResult()` is not allowed to generate any + failure. + +When you add listeners to the listener list, you should put listeners that +handle `OnTestPartResult()` *before* listeners that can generate failures. This +ensures that failures generated by the latter are attributed to the right test +by the former. + +See [sample10_unittest.cc] for an example of a failure-raising listener. + +[sample10_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample10_unittest.cc "Failure-raising listener example" + +## Running Test Programs: Advanced Options + +googletest test programs are ordinary executables. Once built, you can run them +directly and affect their behavior via the following environment variables +and/or command line flags. For the flags to work, your programs must call +`::testing::InitGoogleTest()` before calling `RUN_ALL_TESTS()`. + +To see a list of supported flags and their usage, please run your test program +with the `--help` flag. You can also use `-h`, `-?`, or `/?` for short. + +If an option is specified both by an environment variable and by a flag, the +latter takes precedence. 
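+
+Concretely, a test binary's `main()` just needs to hand the command line to
+googletest before running the tests (a minimal sketch; `InitGoogleTest()`
+removes the flags it recognizes from `argv`, so any remaining arguments are left
+for your own handling):
+
+```c++
+#include "gtest/gtest.h"
+
+int main(int argc, char** argv) {
+  // Parses googletest flags such as --gtest_filter and removes them from argv.
+  testing::InitGoogleTest(&argc, argv);
+  // argc/argv now contain only the arguments googletest did not recognize;
+  // environment variables such as GTEST_FILTER are picked up automatically.
+  return RUN_ALL_TESTS();
+}
+```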
+ +### Selecting Tests + +#### Listing Test Names + +Sometimes it is necessary to list the available tests in a program before +running them so that a filter may be applied if needed. Including the flag +`--gtest_list_tests` overrides all other flags and lists tests in the following +format: + +```none +TestSuite1. + TestName1 + TestName2 +TestSuite2. + TestName +``` + +None of the tests listed are actually run if the flag is provided. There is no +corresponding environment variable for this flag. + +#### Running a Subset of the Tests + +By default, a googletest program runs all tests the user has defined. Sometimes, +you want to run only a subset of the tests (e.g. for debugging or quickly +verifying a change). If you set the `GTEST_FILTER` environment variable or the +`--gtest_filter` flag to a filter string, googletest will only run the tests +whose full names (in the form of `TestSuiteName.TestName`) match the filter. + +The format of a filter is a '`:`'-separated list of wildcard patterns (called +the *positive patterns*) optionally followed by a '`-`' and another +'`:`'-separated pattern list (called the *negative patterns*). A test matches +the filter if and only if it matches any of the positive patterns but does not +match any of the negative patterns. + +A pattern may contain `'*'` (matches any string) or `'?'` (matches any single +character). For convenience, the filter `'*-NegativePatterns'` can be also +written as `'-NegativePatterns'`. + +For example: + +* `./foo_test` Has no flag, and thus runs all its tests. +* `./foo_test --gtest_filter=*` Also runs everything, due to the single + match-everything `*` value. +* `./foo_test --gtest_filter=FooTest.*` Runs everything in test suite + `FooTest` . +* `./foo_test --gtest_filter=*Null*:*Constructor*` Runs any test whose full + name contains either `"Null"` or `"Constructor"` . +* `./foo_test --gtest_filter=-*DeathTest.*` Runs all non-death tests. +* `./foo_test --gtest_filter=FooTest.*-FooTest.Bar` Runs everything in test + suite `FooTest` except `FooTest.Bar`. +* `./foo_test --gtest_filter=FooTest.*:BarTest.*-FooTest.Bar:BarTest.Foo` Runs + everything in test suite `FooTest` except `FooTest.Bar` and everything in + test suite `BarTest` except `BarTest.Foo`. + +#### Stop test execution upon first failure + +By default, a googletest program runs all tests the user has defined. In some +cases (e.g. iterative test development & execution) it may be desirable stop +test execution upon first failure (trading improved latency for completeness). +If `GTEST_FAIL_FAST` environment variable or `--gtest_fail_fast` flag is set, +the test runner will stop execution as soon as the first test failure is found. + +#### Temporarily Disabling Tests + +If you have a broken test that you cannot fix right away, you can add the +`DISABLED_` prefix to its name. This will exclude it from execution. This is +better than commenting out the code or using `#if 0`, as disabled tests are +still compiled (and thus won't rot). + +If you need to disable all tests in a test suite, you can either add `DISABLED_` +to the front of the name of each test, or alternatively add it to the front of +the test suite name. + +For example, the following tests won't be run by googletest, even though they +will still be compiled: + +```c++ +// Tests that Foo does Abc. +TEST(FooTest, DISABLED_DoesAbc) { ... } + +class DISABLED_BarTest : public testing::Test { ... }; + +// Tests that Bar does Xyz. +TEST_F(DISABLED_BarTest, DoesXyz) { ... 
} +``` + +{: .callout .note} +NOTE: This feature should only be used for temporary pain-relief. You still have +to fix the disabled tests at a later date. As a reminder, googletest will print +a banner warning you if a test program contains any disabled tests. + +{: .callout .tip} +TIP: You can easily count the number of disabled tests you have using +`grep`. This number can be used as a metric for +improving your test quality. + +#### Temporarily Enabling Disabled Tests + +To include disabled tests in test execution, just invoke the test program with +the `--gtest_also_run_disabled_tests` flag or set the +`GTEST_ALSO_RUN_DISABLED_TESTS` environment variable to a value other than `0`. +You can combine this with the `--gtest_filter` flag to further select which +disabled tests to run. + +### Repeating the Tests + +Once in a while you'll run into a test whose result is hit-or-miss. Perhaps it +will fail only 1% of the time, making it rather hard to reproduce the bug under +a debugger. This can be a major source of frustration. + +The `--gtest_repeat` flag allows you to repeat all (or selected) test methods in +a program many times. Hopefully, a flaky test will eventually fail and give you +a chance to debug. Here's how to use it: + +```none +$ foo_test --gtest_repeat=1000 +Repeat foo_test 1000 times and don't stop at failures. + +$ foo_test --gtest_repeat=-1 +A negative count means repeating forever. + +$ foo_test --gtest_repeat=1000 --gtest_break_on_failure +Repeat foo_test 1000 times, stopping at the first failure. This +is especially useful when running under a debugger: when the test +fails, it will drop into the debugger and you can then inspect +variables and stacks. + +$ foo_test --gtest_repeat=1000 --gtest_filter=FooBar.* +Repeat the tests whose name matches the filter 1000 times. +``` + +If your test program contains +[global set-up/tear-down](#global-set-up-and-tear-down) code, it will be +repeated in each iteration as well, as the flakiness may be in it. To avoid +repeating global set-up/tear-down, specify +`--gtest_recreate_environments_when_repeating=false`{.nowrap}. + +You can also specify the repeat count by setting the `GTEST_REPEAT` environment +variable. + +### Shuffling the Tests + +You can specify the `--gtest_shuffle` flag (or set the `GTEST_SHUFFLE` +environment variable to `1`) to run the tests in a program in a random order. +This helps to reveal bad dependencies between tests. + +By default, googletest uses a random seed calculated from the current time. +Therefore you'll get a different order every time. The console output includes +the random seed value, such that you can reproduce an order-related test failure +later. To specify the random seed explicitly, use the `--gtest_random_seed=SEED` +flag (or set the `GTEST_RANDOM_SEED` environment variable), where `SEED` is an +integer in the range [0, 99999]. The seed value 0 is special: it tells +googletest to do the default behavior of calculating the seed from the current +time. + +If you combine this with `--gtest_repeat=N`, googletest will pick a different +random seed and re-shuffle the tests in each iteration. + +### Distributing Test Functions to Multiple Machines + +If you have more than one machine you can use to run a test program, you might +want to run the test functions in parallel and get the result faster. We call +this technique *sharding*, where each machine is called a *shard*. + +GoogleTest is compatible with test sharding. 
To take advantage of this feature, +your test runner (not part of GoogleTest) needs to do the following: + +1. Allocate a number of machines (shards) to run the tests. +1. On each shard, set the `GTEST_TOTAL_SHARDS` environment variable to the total + number of shards. It must be the same for all shards. +1. On each shard, set the `GTEST_SHARD_INDEX` environment variable to the index + of the shard. Different shards must be assigned different indices, which + must be in the range `[0, GTEST_TOTAL_SHARDS - 1]`. +1. Run the same test program on all shards. When GoogleTest sees the above two + environment variables, it will select a subset of the test functions to run. + Across all shards, each test function in the program will be run exactly + once. +1. Wait for all shards to finish, then collect and report the results. + +Your project may have tests that were written without GoogleTest and thus don't +understand this protocol. In order for your test runner to figure out which test +supports sharding, it can set the environment variable `GTEST_SHARD_STATUS_FILE` +to a non-existent file path. If a test program supports sharding, it will create +this file to acknowledge that fact; otherwise it will not create it. The actual +contents of the file are not important at this time, although we may put some +useful information in it in the future. + +Here's an example to make it clear. Suppose you have a test program `foo_test` +that contains the following 5 test functions: + +``` +TEST(A, V) +TEST(A, W) +TEST(B, X) +TEST(B, Y) +TEST(B, Z) +``` + +Suppose you have 3 machines at your disposal. To run the test functions in +parallel, you would set `GTEST_TOTAL_SHARDS` to 3 on all machines, and set +`GTEST_SHARD_INDEX` to 0, 1, and 2 on the machines respectively. Then you would +run the same `foo_test` on each machine. + +GoogleTest reserves the right to change how the work is distributed across the +shards, but here's one possible scenario: + +* Machine #0 runs `A.V` and `B.X`. +* Machine #1 runs `A.W` and `B.Y`. +* Machine #2 runs `B.Z`. + +### Controlling Test Output + +#### Colored Terminal Output + +googletest can use colors in its terminal output to make it easier to spot the +important information: + +
+```none
+...
+[----------] 1 test from FooTest
+[ RUN      ] FooTest.DoesAbc
+[       OK ] FooTest.DoesAbc
+[----------] 2 tests from BarTest
+[ RUN      ] BarTest.HasXyzProperty
+[       OK ] BarTest.HasXyzProperty
+[ RUN      ] BarTest.ReturnsTrueOnSuccess
+... some error messages ...
+[   FAILED ] BarTest.ReturnsTrueOnSuccess
+...
+[==========] 30 tests from 14 test suites ran.
+[   PASSED ] 28 tests.
+[   FAILED ] 2 tests, listed below:
+[   FAILED ] BarTest.ReturnsTrueOnSuccess
+[   FAILED ] AnotherTest.DoesXyz
+
+ 2 FAILED TESTS
+```
+
+You can set the `GTEST_COLOR` environment variable or the `--gtest_color`
+command line flag to `yes`, `no`, or `auto` (the default) to enable colors,
+disable colors, or let googletest decide. When the value is `auto`, googletest
+will use colors if and only if the output goes to a terminal and (on non-Windows
+platforms) the `TERM` environment variable is set to `xterm` or `xterm-color`.
+
+#### Suppressing test passes
+
+By default, googletest prints 1 line of output for each test, indicating if it
+passed or failed. To show only test failures, run the test program with
+`--gtest_brief=1`, or set the `GTEST_BRIEF` environment variable to `1`.
+
+#### Suppressing the Elapsed Time
+
+By default, googletest prints the time it takes to run each test. To disable
+that, run the test program with the `--gtest_print_time=0` command line flag, or
+set the `GTEST_PRINT_TIME` environment variable to `0`.
+
+#### Suppressing UTF-8 Text Output
+
+In case of assertion failures, googletest prints expected and actual values of
+type `string` both as hex-encoded strings and as readable UTF-8 text if they
+contain valid non-ASCII UTF-8 characters. If you want to suppress the UTF-8
+text because, for example, you don't have a UTF-8 compatible output medium, run
+the test program with `--gtest_print_utf8=0` or set the `GTEST_PRINT_UTF8`
+environment variable to `0`.
+
+#### Generating an XML Report
+
+googletest can emit a detailed XML report to a file in addition to its normal
+textual output. The report contains the duration of each test, and thus can help
+you identify slow tests.
+
+To generate the XML report, set the `GTEST_OUTPUT` environment variable or the
+`--gtest_output` flag to the string `"xml:path_to_output_file"`, which will
+create the file at the given location. You can also just use the string `"xml"`,
+in which case the output can be found in the `test_detail.xml` file in the
+current directory.
+
+If you specify a directory (for example, `"xml:output/directory/"` on Linux or
+`"xml:output\directory\"` on Windows), googletest will create the XML file in
+that directory, named after the test executable (e.g. `foo_test.xml` for test
+program `foo_test` or `foo_test.exe`). If the file already exists (perhaps left
+over from a previous run), googletest will pick a different name (e.g.
+`foo_test_1.xml`) to avoid overwriting it.
+
+The report is based on the `junitreport` Ant task. Since that format was
+originally intended for Java, a little interpretation is required to make it
+apply to googletest tests, as shown here:
+
+```xml
+<testsuites name="AllTests" ...>
+  <testsuite name="test_case_name" ...>
+    <testcase name="test_name" ...>
+      <failure message="..."/>
+      <failure message="..."/>
+      <failure message="..."/>
+    </testcase>
+  </testsuite>
+</testsuites>
+```
+
+* The root `<testsuites>` element corresponds to the entire test program.
+* `<testsuite>` elements correspond to googletest test suites.
+* `<testcase>` elements correspond to googletest test functions.
+
+For instance, the following program
+
+```c++
+TEST(MathTest, Addition) { ... }
+TEST(MathTest, Subtraction) { ... }
+TEST(LogicTest, NonContradiction) { ... }
+```
+
+could generate this report:
+
+```xml
+<testsuites tests="3" failures="1" errors="0" time="0.035" timestamp="2011-10-31T18:52:42" name="AllTests">
+  <testsuite name="MathTest" tests="2" failures="1" errors="0" time="0.015">
+    <testcase name="Addition" file="test.cpp" line="1" status="run" time="0.007" classname="">
+      <failure message="Value of: add(1, 1)&#x0A;  Actual: 3&#x0A;Expected: 2" type="">...</failure>
+      <failure message="Value of: add(1, -1)&#x0A;  Actual: 1&#x0A;Expected: 0" type="">...</failure>
+    </testcase>
+    <testcase name="Subtraction" file="test.cpp" line="2" status="run" time="0.005" classname=""/>
+  </testsuite>
+  <testsuite name="LogicTest" tests="1" failures="0" errors="0" time="0.005">
+    <testcase name="NonContradiction" file="test.cpp" line="3" status="run" time="0.005" classname=""/>
+  </testsuite>
+</testsuites>
+```
+
+Things to note:
+
+* The `tests` attribute of a `<testsuites>` or `<testsuite>` element tells how
+  many test functions the googletest program or test suite contains, while the
+  `failures` attribute tells how many of them failed.
+
+* The `time` attribute expresses the duration of the test, test suite, or
+  entire test program in seconds.
+
+* The `timestamp` attribute records the local date and time of the test
+  execution.
+
+* The `file` and `line` attributes record the source file location where the
+  test was defined.
+ +* Each `` element corresponds to a single failed googletest + assertion. + +#### Generating a JSON Report + +googletest can also emit a JSON report as an alternative format to XML. To +generate the JSON report, set the `GTEST_OUTPUT` environment variable or the +`--gtest_output` flag to the string `"json:path_to_output_file"`, which will +create the file at the given location. You can also just use the string +`"json"`, in which case the output can be found in the `test_detail.json` file +in the current directory. + +The report format conforms to the following JSON Schema: + +```json +{ + "$schema": "http://json-schema.org/schema#", + "type": "object", + "definitions": { + "TestCase": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "tests": { "type": "integer" }, + "failures": { "type": "integer" }, + "disabled": { "type": "integer" }, + "time": { "type": "string" }, + "testsuite": { + "type": "array", + "items": { + "$ref": "#/definitions/TestInfo" + } + } + } + }, + "TestInfo": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "file": { "type": "string" }, + "line": { "type": "integer" }, + "status": { + "type": "string", + "enum": ["RUN", "NOTRUN"] + }, + "time": { "type": "string" }, + "classname": { "type": "string" }, + "failures": { + "type": "array", + "items": { + "$ref": "#/definitions/Failure" + } + } + } + }, + "Failure": { + "type": "object", + "properties": { + "failures": { "type": "string" }, + "type": { "type": "string" } + } + } + }, + "properties": { + "tests": { "type": "integer" }, + "failures": { "type": "integer" }, + "disabled": { "type": "integer" }, + "errors": { "type": "integer" }, + "timestamp": { + "type": "string", + "format": "date-time" + }, + "time": { "type": "string" }, + "name": { "type": "string" }, + "testsuites": { + "type": "array", + "items": { + "$ref": "#/definitions/TestCase" + } + } + } +} +``` + +The report uses the format that conforms to the following Proto3 using the +[JSON encoding](https://developers.google.com/protocol-buffers/docs/proto3#json): + +```proto +syntax = "proto3"; + +package googletest; + +import "google/protobuf/timestamp.proto"; +import "google/protobuf/duration.proto"; + +message UnitTest { + int32 tests = 1; + int32 failures = 2; + int32 disabled = 3; + int32 errors = 4; + google.protobuf.Timestamp timestamp = 5; + google.protobuf.Duration time = 6; + string name = 7; + repeated TestCase testsuites = 8; +} + +message TestCase { + string name = 1; + int32 tests = 2; + int32 failures = 3; + int32 disabled = 4; + int32 errors = 5; + google.protobuf.Duration time = 6; + repeated TestInfo testsuite = 7; +} + +message TestInfo { + string name = 1; + string file = 6; + int32 line = 7; + enum Status { + RUN = 0; + NOTRUN = 1; + } + Status status = 2; + google.protobuf.Duration time = 3; + string classname = 4; + message Failure { + string failures = 1; + string type = 2; + } + repeated Failure failures = 5; +} +``` + +For instance, the following program + +```c++ +TEST(MathTest, Addition) { ... } +TEST(MathTest, Subtraction) { ... } +TEST(LogicTest, NonContradiction) { ... 
} +``` + +could generate this report: + +```json +{ + "tests": 3, + "failures": 1, + "errors": 0, + "time": "0.035s", + "timestamp": "2011-10-31T18:52:42Z", + "name": "AllTests", + "testsuites": [ + { + "name": "MathTest", + "tests": 2, + "failures": 1, + "errors": 0, + "time": "0.015s", + "testsuite": [ + { + "name": "Addition", + "file": "test.cpp", + "line": 1, + "status": "RUN", + "time": "0.007s", + "classname": "", + "failures": [ + { + "message": "Value of: add(1, 1)\n Actual: 3\nExpected: 2", + "type": "" + }, + { + "message": "Value of: add(1, -1)\n Actual: 1\nExpected: 0", + "type": "" + } + ] + }, + { + "name": "Subtraction", + "file": "test.cpp", + "line": 2, + "status": "RUN", + "time": "0.005s", + "classname": "" + } + ] + }, + { + "name": "LogicTest", + "tests": 1, + "failures": 0, + "errors": 0, + "time": "0.005s", + "testsuite": [ + { + "name": "NonContradiction", + "file": "test.cpp", + "line": 3, + "status": "RUN", + "time": "0.005s", + "classname": "" + } + ] + } + ] +} +``` + +{: .callout .important} +IMPORTANT: The exact format of the JSON document is subject to change. + +### Controlling How Failures Are Reported + +#### Detecting Test Premature Exit + +Google Test implements the _premature-exit-file_ protocol for test runners to +catch any kind of unexpected exits of test programs. Upon start, Google Test +creates the file which will be automatically deleted after all work has been +finished. Then, the test runner can check if this file exists. In case the file +remains undeleted, the inspected test has exited prematurely. + +This feature is enabled only if the `TEST_PREMATURE_EXIT_FILE` environment +variable has been set. + +#### Turning Assertion Failures into Break-Points + +When running test programs under a debugger, it's very convenient if the +debugger can catch an assertion failure and automatically drop into interactive +mode. googletest's *break-on-failure* mode supports this behavior. + +To enable it, set the `GTEST_BREAK_ON_FAILURE` environment variable to a value +other than `0`. Alternatively, you can use the `--gtest_break_on_failure` +command line flag. + +#### Disabling Catching Test-Thrown Exceptions + +googletest can be used either with or without exceptions enabled. If a test +throws a C++ exception or (on Windows) a structured exception (SEH), by default +googletest catches it, reports it as a test failure, and continues with the next +test method. This maximizes the coverage of a test run. Also, on Windows an +uncaught exception will cause a pop-up window, so catching the exceptions allows +you to run the tests automatically. + +When debugging the test failures, however, you may instead want the exceptions +to be handled by the debugger, such that you can examine the call stack when an +exception is thrown. To achieve that, set the `GTEST_CATCH_EXCEPTIONS` +environment variable to `0`, or use the `--gtest_catch_exceptions=0` flag when +running the tests. + +### Sanitizer Integration + +The +[Undefined Behavior Sanitizer](https://clang.llvm.org/docs/UndefinedBehaviorSanitizer.html), +[Address Sanitizer](https://github.com/google/sanitizers/wiki/AddressSanitizer), +and +[Thread Sanitizer](https://github.com/google/sanitizers/wiki/ThreadSanitizerCppManual) +all provide weak functions that you can override to trigger explicit failures +when they detect sanitizer errors, such as creating a reference from `nullptr`. 
+To override these functions, place definitions for them in a source file that +you compile as part of your main binary: + +``` +extern "C" { +void __ubsan_on_report() { + FAIL() << "Encountered an undefined behavior sanitizer error"; +} +void __asan_on_error() { + FAIL() << "Encountered an address sanitizer error"; +} +void __tsan_on_report() { + FAIL() << "Encountered a thread sanitizer error"; +} +} // extern "C" +``` + +After compiling your project with one of the sanitizers enabled, if a particular +test triggers a sanitizer error, googletest will report that it failed. diff --git a/3rdparty/googletest-1.13.0/docs/assets/css/style.scss b/3rdparty/googletest-1.13.0/docs/assets/css/style.scss new file mode 100644 index 0000000000000000000000000000000000000000..bb30f418da7b92d8eaa1fd63caac99aa0576e91c --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/assets/css/style.scss @@ -0,0 +1,5 @@ +--- +--- + +@import "jekyll-theme-primer"; +@import "main"; diff --git a/3rdparty/googletest-1.13.0/docs/community_created_documentation.md b/3rdparty/googletest-1.13.0/docs/community_created_documentation.md new file mode 100644 index 0000000000000000000000000000000000000000..4569075ff23be385fce1656a8db2f77631d00d42 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/community_created_documentation.md @@ -0,0 +1,7 @@ +# Community-Created Documentation + +The following is a list, in no particular order, of links to documentation +created by the Googletest community. + +* [Googlemock Insights](https://github.com/ElectricRCAircraftGuy/eRCaGuy_dotfiles/blob/master/googletest/insights.md), + by [ElectricRCAircraftGuy](https://github.com/ElectricRCAircraftGuy) diff --git a/3rdparty/googletest-1.13.0/docs/faq.md b/3rdparty/googletest-1.13.0/docs/faq.md new file mode 100644 index 0000000000000000000000000000000000000000..1928097292a238a81269a01c8e6bd16c96c61b9b --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/faq.md @@ -0,0 +1,692 @@ +# GoogleTest FAQ + +## Why should test suite names and test names not contain underscore? + +{: .callout .note} +Note: GoogleTest reserves underscore (`_`) for special purpose keywords, such as +[the `DISABLED_` prefix](advanced.md#temporarily-disabling-tests), in addition +to the following rationale. + +Underscore (`_`) is special, as C++ reserves the following to be used by the +compiler and the standard library: + +1. any identifier that starts with an `_` followed by an upper-case letter, and +2. any identifier that contains two consecutive underscores (i.e. `__`) + *anywhere* in its name. + +User code is *prohibited* from using such identifiers. + +Now let's look at what this means for `TEST` and `TEST_F`. + +Currently `TEST(TestSuiteName, TestName)` generates a class named +`TestSuiteName_TestName_Test`. What happens if `TestSuiteName` or `TestName` +contains `_`? + +1. If `TestSuiteName` starts with an `_` followed by an upper-case letter (say, + `_Foo`), we end up with `_Foo_TestName_Test`, which is reserved and thus + invalid. +2. If `TestSuiteName` ends with an `_` (say, `Foo_`), we get + `Foo__TestName_Test`, which is invalid. +3. If `TestName` starts with an `_` (say, `_Bar`), we get + `TestSuiteName__Bar_Test`, which is invalid. +4. If `TestName` ends with an `_` (say, `Bar_`), we get + `TestSuiteName_Bar__Test`, which is invalid. + +So clearly `TestSuiteName` and `TestName` cannot start or end with `_` +(Actually, `TestSuiteName` can start with `_` -- as long as the `_` isn't +followed by an upper-case letter. But that's getting complicated. 
So for +simplicity we just say that it cannot start with `_`.). + +It may seem fine for `TestSuiteName` and `TestName` to contain `_` in the +middle. However, consider this: + +```c++ +TEST(Time, Flies_Like_An_Arrow) { ... } +TEST(Time_Flies, Like_An_Arrow) { ... } +``` + +Now, the two `TEST`s will both generate the same class +(`Time_Flies_Like_An_Arrow_Test`). That's not good. + +So for simplicity, we just ask the users to avoid `_` in `TestSuiteName` and +`TestName`. The rule is more constraining than necessary, but it's simple and +easy to remember. It also gives GoogleTest some wiggle room in case its +implementation needs to change in the future. + +If you violate the rule, there may not be immediate consequences, but your test +may (just may) break with a new compiler (or a new version of the compiler you +are using) or with a new version of GoogleTest. Therefore it's best to follow +the rule. + +## Why does GoogleTest support `EXPECT_EQ(NULL, ptr)` and `ASSERT_EQ(NULL, ptr)` but not `EXPECT_NE(NULL, ptr)` and `ASSERT_NE(NULL, ptr)`? + +First of all, you can use `nullptr` with each of these macros, e.g. +`EXPECT_EQ(ptr, nullptr)`, `EXPECT_NE(ptr, nullptr)`, `ASSERT_EQ(ptr, nullptr)`, +`ASSERT_NE(ptr, nullptr)`. This is the preferred syntax in the style guide +because `nullptr` does not have the type problems that `NULL` does. + +Due to some peculiarity of C++, it requires some non-trivial template meta +programming tricks to support using `NULL` as an argument of the `EXPECT_XX()` +and `ASSERT_XX()` macros. Therefore we only do it where it's most needed +(otherwise we make the implementation of GoogleTest harder to maintain and more +error-prone than necessary). + +Historically, the `EXPECT_EQ()` macro took the *expected* value as its first +argument and the *actual* value as the second, though this argument order is now +discouraged. It was reasonable that someone wanted +to write `EXPECT_EQ(NULL, some_expression)`, and this indeed was requested +several times. Therefore we implemented it. + +The need for `EXPECT_NE(NULL, ptr)` wasn't nearly as strong. When the assertion +fails, you already know that `ptr` must be `NULL`, so it doesn't add any +information to print `ptr` in this case. That means `EXPECT_TRUE(ptr != NULL)` +works just as well. + +If we were to support `EXPECT_NE(NULL, ptr)`, for consistency we'd have to +support `EXPECT_NE(ptr, NULL)` as well. This means using the template meta +programming tricks twice in the implementation, making it even harder to +understand and maintain. We believe the benefit doesn't justify the cost. + +Finally, with the growth of the gMock matcher library, we are encouraging people +to use the unified `EXPECT_THAT(value, matcher)` syntax more often in tests. One +significant advantage of the matcher approach is that matchers can be easily +combined to form new matchers, while the `EXPECT_NE`, etc, macros cannot be +easily combined. Therefore we want to invest more in the matchers than in the +`EXPECT_XX()` macros. + +## I need to test that different implementations of an interface satisfy some common requirements. Should I use typed tests or value-parameterized tests? + +For testing various implementations of the same interface, either typed tests or +value-parameterized tests can get it done. It's really up to you the user to +decide which is more convenient for you, depending on your particular case. 
Some +rough guidelines: + +* Typed tests can be easier to write if instances of the different + implementations can be created the same way, modulo the type. For example, + if all these implementations have a public default constructor (such that + you can write `new TypeParam`), or if their factory functions have the same + form (e.g. `CreateInstance()`). +* Value-parameterized tests can be easier to write if you need different code + patterns to create different implementations' instances, e.g. `new Foo` vs + `new Bar(5)`. To accommodate for the differences, you can write factory + function wrappers and pass these function pointers to the tests as their + parameters. +* When a typed test fails, the default output includes the name of the type, + which can help you quickly identify which implementation is wrong. + Value-parameterized tests only show the number of the failed iteration by + default. You will need to define a function that returns the iteration name + and pass it as the third parameter to INSTANTIATE_TEST_SUITE_P to have more + useful output. +* When using typed tests, you need to make sure you are testing against the + interface type, not the concrete types (in other words, you want to make + sure `implicit_cast(my_concrete_impl)` works, not just that + `my_concrete_impl` works). It's less likely to make mistakes in this area + when using value-parameterized tests. + +I hope I didn't confuse you more. :-) If you don't mind, I'd suggest you to give +both approaches a try. Practice is a much better way to grasp the subtle +differences between the two tools. Once you have some concrete experience, you +can much more easily decide which one to use the next time. + +## I got some run-time errors about invalid proto descriptors when using `ProtocolMessageEquals`. Help! + +{: .callout .note} +**Note:** `ProtocolMessageEquals` and `ProtocolMessageEquiv` are *deprecated* +now. Please use `EqualsProto`, etc instead. + +`ProtocolMessageEquals` and `ProtocolMessageEquiv` were redefined recently and +are now less tolerant of invalid protocol buffer definitions. In particular, if +you have a `foo.proto` that doesn't fully qualify the type of a protocol message +it references (e.g. `message` where it should be `message`), you +will now get run-time errors like: + +``` +... descriptor.cc:...] Invalid proto descriptor for file "path/to/foo.proto": +... descriptor.cc:...] blah.MyMessage.my_field: ".Bar" is not defined. +``` + +If you see this, your `.proto` file is broken and needs to be fixed by making +the types fully qualified. The new definition of `ProtocolMessageEquals` and +`ProtocolMessageEquiv` just happen to reveal your bug. + +## My death test modifies some state, but the change seems lost after the death test finishes. Why? + +Death tests (`EXPECT_DEATH`, etc) are executed in a sub-process s.t. the +expected crash won't kill the test program (i.e. the parent process). As a +result, any in-memory side effects they incur are observable in their respective +sub-processes, but not in the parent process. You can think of them as running +in a parallel universe, more or less. + +In particular, if you use mocking and the death test statement invokes some mock +methods, the parent process will think the calls have never occurred. Therefore, +you may want to move your `EXPECT_CALL` statements inside the `EXPECT_DEATH` +macro. + +## EXPECT_EQ(htonl(blah), blah_blah) generates weird compiler errors in opt mode. Is this a GoogleTest bug? + +Actually, the bug is in `htonl()`. 
+ +According to `'man htonl'`, `htonl()` is a *function*, which means it's valid to +use `htonl` as a function pointer. However, in opt mode `htonl()` is defined as +a *macro*, which breaks this usage. + +Worse, the macro definition of `htonl()` uses a `gcc` extension and is *not* +standard C++. That hacky implementation has some ad hoc limitations. In +particular, it prevents you from writing `Foo()`, where `Foo` +is a template that has an integral argument. + +The implementation of `EXPECT_EQ(a, b)` uses `sizeof(... a ...)` inside a +template argument, and thus doesn't compile in opt mode when `a` contains a call +to `htonl()`. It is difficult to make `EXPECT_EQ` bypass the `htonl()` bug, as +the solution must work with different compilers on various platforms. + +## The compiler complains about "undefined references" to some static const member variables, but I did define them in the class body. What's wrong? + +If your class has a static data member: + +```c++ +// foo.h +class Foo { + ... + static const int kBar = 100; +}; +``` + +You also need to define it *outside* of the class body in `foo.cc`: + +```c++ +const int Foo::kBar; // No initializer here. +``` + +Otherwise your code is **invalid C++**, and may break in unexpected ways. In +particular, using it in GoogleTest comparison assertions (`EXPECT_EQ`, etc) will +generate an "undefined reference" linker error. The fact that "it used to work" +doesn't mean it's valid. It just means that you were lucky. :-) + +If the declaration of the static data member is `constexpr` then it is +implicitly an `inline` definition, and a separate definition in `foo.cc` is not +needed: + +```c++ +// foo.h +class Foo { + ... + static constexpr int kBar = 100; // Defines kBar, no need to do it in foo.cc. +}; +``` + +## Can I derive a test fixture from another? + +Yes. + +Each test fixture has a corresponding and same named test suite. This means only +one test suite can use a particular fixture. Sometimes, however, multiple test +cases may want to use the same or slightly different fixtures. For example, you +may want to make sure that all of a GUI library's test suites don't leak +important system resources like fonts and brushes. + +In GoogleTest, you share a fixture among test suites by putting the shared logic +in a base test fixture, then deriving from that base a separate fixture for each +test suite that wants to use this common logic. You then use `TEST_F()` to write +tests using each derived fixture. + +Typically, your code looks like this: + +```c++ +// Defines a base test fixture. +class BaseTest : public ::testing::Test { + protected: + ... +}; + +// Derives a fixture FooTest from BaseTest. +class FooTest : public BaseTest { + protected: + void SetUp() override { + BaseTest::SetUp(); // Sets up the base fixture first. + ... additional set-up work ... + } + + void TearDown() override { + ... clean-up work for FooTest ... + BaseTest::TearDown(); // Remember to tear down the base fixture + // after cleaning up FooTest! + } + + ... functions and variables for FooTest ... +}; + +// Tests that use the fixture FooTest. +TEST_F(FooTest, Bar) { ... } +TEST_F(FooTest, Baz) { ... } + +... additional fixtures derived from BaseTest ... +``` + +If necessary, you can continue to derive test fixtures from a derived fixture. +GoogleTest has no limit on how deep the hierarchy can be. + +For a complete example using derived test fixtures, see +[sample5_unittest.cc](https://github.com/google/googletest/blob/main/googletest/samples/sample5_unittest.cc). 
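+
+As an illustrative sketch of a deeper hierarchy (the class and test names here
+are made up), a second level of derivation simply repeats the same pattern:
+
+```c++
+// Derives yet another fixture from FooTest; the same chaining rules apply.
+class FancyFooTest : public FooTest {
+ protected:
+  void SetUp() override {
+    FooTest::SetUp();  // Runs BaseTest::SetUp(), then FooTest's set-up.
+    // ... additional set-up work for FancyFooTest ...
+  }
+
+  void TearDown() override {
+    // ... clean-up work for FancyFooTest ...
+    FooTest::TearDown();  // Tears down FooTest, then BaseTest.
+  }
+};
+
+TEST_F(FancyFooTest, Qux) { ... }
+```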
+ +## My compiler complains "void value not ignored as it ought to be." What does this mean? + +You're probably using an `ASSERT_*()` in a function that doesn't return `void`. +`ASSERT_*()` can only be used in `void` functions, due to exceptions being +disabled by our build system. Please see more details +[here](advanced.md#assertion-placement). + +## My death test hangs (or seg-faults). How do I fix it? + +In GoogleTest, death tests are run in a child process and the way they work is +delicate. To write death tests you really need to understand how they work—see +the details at [Death Assertions](reference/assertions.md#death) in the +Assertions Reference. + +In particular, death tests don't like having multiple threads in the parent +process. So the first thing you can try is to eliminate creating threads outside +of `EXPECT_DEATH()`. For example, you may want to use mocks or fake objects +instead of real ones in your tests. + +Sometimes this is impossible as some library you must use may be creating +threads before `main()` is even reached. In this case, you can try to minimize +the chance of conflicts by either moving as many activities as possible inside +`EXPECT_DEATH()` (in the extreme case, you want to move everything inside), or +leaving as few things as possible in it. Also, you can try to set the death test +style to `"threadsafe"`, which is safer but slower, and see if it helps. + +If you go with thread-safe death tests, remember that they rerun the test +program from the beginning in the child process. Therefore make sure your +program can run side-by-side with itself and is deterministic. + +In the end, this boils down to good concurrent programming. You have to make +sure that there are no race conditions or deadlocks in your program. No silver +bullet - sorry! + +## Should I use the constructor/destructor of the test fixture or SetUp()/TearDown()? {#CtorVsSetUp} + +The first thing to remember is that GoogleTest does **not** reuse the same test +fixture object across multiple tests. For each `TEST_F`, GoogleTest will create +a **fresh** test fixture object, immediately call `SetUp()`, run the test body, +call `TearDown()`, and then delete the test fixture object. + +When you need to write per-test set-up and tear-down logic, you have the choice +between using the test fixture constructor/destructor or `SetUp()/TearDown()`. +The former is usually preferred, as it has the following benefits: + +* By initializing a member variable in the constructor, we have the option to + make it `const`, which helps prevent accidental changes to its value and + makes the tests more obviously correct. +* In case we need to subclass the test fixture class, the subclass' + constructor is guaranteed to call the base class' constructor *first*, and + the subclass' destructor is guaranteed to call the base class' destructor + *afterward*. With `SetUp()/TearDown()`, a subclass may make the mistake of + forgetting to call the base class' `SetUp()/TearDown()` or call them at the + wrong time. + +You may still want to use `SetUp()/TearDown()` in the following cases: + +* C++ does not allow virtual function calls in constructors and destructors. + You can call a method declared as virtual, but it will not use dynamic + dispatch. It will use the definition from the class the constructor of which + is currently executing. This is because calling a virtual method before the + derived class constructor has a chance to run is very dangerous - the + virtual method might operate on uninitialized data. 
Therefore, if you need + to call a method that will be overridden in a derived class, you have to use + `SetUp()/TearDown()`. +* In the body of a constructor (or destructor), it's not possible to use the + `ASSERT_xx` macros. Therefore, if the set-up operation could cause a fatal + test failure that should prevent the test from running, it's necessary to + use `abort` and abort the whole test + executable, or to use `SetUp()` instead of a constructor. +* If the tear-down operation could throw an exception, you must use + `TearDown()` as opposed to the destructor, as throwing in a destructor leads + to undefined behavior and usually will kill your program right away. Note + that many standard libraries (like STL) may throw when exceptions are + enabled in the compiler. Therefore you should prefer `TearDown()` if you + want to write portable tests that work with or without exceptions. +* The GoogleTest team is considering making the assertion macros throw on + platforms where exceptions are enabled (e.g. Windows, Mac OS, and Linux + client-side), which will eliminate the need for the user to propagate + failures from a subroutine to its caller. Therefore, you shouldn't use + GoogleTest assertions in a destructor if your code could run on such a + platform. + +## The compiler complains "no matching function to call" when I use ASSERT_PRED*. How do I fix it? + +See details for [`EXPECT_PRED*`](reference/assertions.md#EXPECT_PRED) in the +Assertions Reference. + +## My compiler complains about "ignoring return value" when I call RUN_ALL_TESTS(). Why? + +Some people had been ignoring the return value of `RUN_ALL_TESTS()`. That is, +instead of + +```c++ + return RUN_ALL_TESTS(); +``` + +they write + +```c++ + RUN_ALL_TESTS(); +``` + +This is **wrong and dangerous**. The testing services needs to see the return +value of `RUN_ALL_TESTS()` in order to determine if a test has passed. If your +`main()` function ignores it, your test will be considered successful even if it +has a GoogleTest assertion failure. Very bad. + +We have decided to fix this (thanks to Michael Chastain for the idea). Now, your +code will no longer be able to ignore `RUN_ALL_TESTS()` when compiled with +`gcc`. If you do so, you'll get a compiler error. + +If you see the compiler complaining about you ignoring the return value of +`RUN_ALL_TESTS()`, the fix is simple: just make sure its value is used as the +return value of `main()`. + +But how could we introduce a change that breaks existing tests? Well, in this +case, the code was already broken in the first place, so we didn't break it. :-) + +## My compiler complains that a constructor (or destructor) cannot return a value. What's going on? + +Due to a peculiarity of C++, in order to support the syntax for streaming +messages to an `ASSERT_*`, e.g. + +```c++ + ASSERT_EQ(1, Foo()) << "blah blah" << foo; +``` + +we had to give up using `ASSERT*` and `FAIL*` (but not `EXPECT*` and +`ADD_FAILURE*`) in constructors and destructors. The workaround is to move the +content of your constructor/destructor to a private void member function, or +switch to `EXPECT_*()` if that works. This +[section](advanced.md#assertion-placement) in the user's guide explains it. + +## My SetUp() function is not called. Why? + +C++ is case-sensitive. Did you spell it as `Setup()`? + +Similarly, sometimes people spell `SetUpTestSuite()` as `SetupTestSuite()` and +wonder why it's never called. 
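+
+One way to catch this class of typo at compile time is to mark the method with
+`override` (a small sketch; `override` is optional, but with it the compiler
+rejects a misspelled name because there is nothing virtual to override):
+
+```c++
+class FooTest : public ::testing::Test {
+ protected:
+  void SetUp() override { ... }   // OK: overrides testing::Test::SetUp().
+
+  // void Setup() override { ... }  // Would not compile: there is no virtual
+  //                                // Setup() in testing::Test to override.
+};
+```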
+ +## I have several test suites which share the same test fixture logic, do I have to define a new test fixture class for each of them? This seems pretty tedious. + +You don't have to. Instead of + +```c++ +class FooTest : public BaseTest {}; + +TEST_F(FooTest, Abc) { ... } +TEST_F(FooTest, Def) { ... } + +class BarTest : public BaseTest {}; + +TEST_F(BarTest, Abc) { ... } +TEST_F(BarTest, Def) { ... } +``` + +you can simply `typedef` the test fixtures: + +```c++ +typedef BaseTest FooTest; + +TEST_F(FooTest, Abc) { ... } +TEST_F(FooTest, Def) { ... } + +typedef BaseTest BarTest; + +TEST_F(BarTest, Abc) { ... } +TEST_F(BarTest, Def) { ... } +``` + +## GoogleTest output is buried in a whole bunch of LOG messages. What do I do? + +The GoogleTest output is meant to be a concise and human-friendly report. If +your test generates textual output itself, it will mix with the GoogleTest +output, making it hard to read. However, there is an easy solution to this +problem. + +Since `LOG` messages go to stderr, we decided to let GoogleTest output go to +stdout. This way, you can easily separate the two using redirection. For +example: + +```shell +$ ./my_test > gtest_output.txt +``` + +## Why should I prefer test fixtures over global variables? + +There are several good reasons: + +1. It's likely your test needs to change the states of its global variables. + This makes it difficult to keep side effects from escaping one test and + contaminating others, making debugging difficult. By using fixtures, each + test has a fresh set of variables that's different (but with the same + names). Thus, tests are kept independent of each other. +2. Global variables pollute the global namespace. +3. Test fixtures can be reused via subclassing, which cannot be done easily + with global variables. This is useful if many test suites have something in + common. + +## What can the statement argument in ASSERT_DEATH() be? + +`ASSERT_DEATH(statement, matcher)` (or any death assertion macro) can be used +wherever *`statement`* is valid. So basically *`statement`* can be any C++ +statement that makes sense in the current context. In particular, it can +reference global and/or local variables, and can be: + +* a simple function call (often the case), +* a complex expression, or +* a compound statement. + +Some examples are shown here: + +```c++ +// A death test can be a simple function call. +TEST(MyDeathTest, FunctionCall) { + ASSERT_DEATH(Xyz(5), "Xyz failed"); +} + +// Or a complex expression that references variables and functions. +TEST(MyDeathTest, ComplexExpression) { + const bool c = Condition(); + ASSERT_DEATH((c ? Func1(0) : object2.Method("test")), + "(Func1|Method) failed"); +} + +// Death assertions can be used anywhere in a function. In +// particular, they can be inside a loop. +TEST(MyDeathTest, InsideLoop) { + // Verifies that Foo(0), Foo(1), ..., and Foo(4) all die. + for (int i = 0; i < 5; i++) { + EXPECT_DEATH_M(Foo(i), "Foo has \\d+ errors", + ::testing::Message() << "where i is " << i); + } +} + +// A death assertion can contain a compound statement. +TEST(MyDeathTest, CompoundStatement) { + // Verifies that at lease one of Bar(0), Bar(1), ..., and + // Bar(4) dies. + ASSERT_DEATH({ + for (int i = 0; i < 5; i++) { + Bar(i); + } + }, + "Bar has \\d+ errors"); +} +``` + +## I have a fixture class `FooTest`, but `TEST_F(FooTest, Bar)` gives me error ``"no matching function for call to `FooTest::FooTest()'"``. Why? 
+ +GoogleTest needs to be able to create objects of your test fixture class, so it +must have a default constructor. Normally the compiler will define one for you. +However, there are cases where you have to define your own: + +* If you explicitly declare a non-default constructor for class `FooTest` + (`DISALLOW_EVIL_CONSTRUCTORS()` does this), then you need to define a + default constructor, even if it would be empty. +* If `FooTest` has a const non-static data member, then you have to define the + default constructor *and* initialize the const member in the initializer + list of the constructor. (Early versions of `gcc` doesn't force you to + initialize the const member. It's a bug that has been fixed in `gcc 4`.) + +## Why does ASSERT_DEATH complain about previous threads that were already joined? + +With the Linux pthread library, there is no turning back once you cross the line +from a single thread to multiple threads. The first time you create a thread, a +manager thread is created in addition, so you get 3, not 2, threads. Later when +the thread you create joins the main thread, the thread count decrements by 1, +but the manager thread will never be killed, so you still have 2 threads, which +means you cannot safely run a death test. + +The new NPTL thread library doesn't suffer from this problem, as it doesn't +create a manager thread. However, if you don't control which machine your test +runs on, you shouldn't depend on this. + +## Why does GoogleTest require the entire test suite, instead of individual tests, to be named *DeathTest when it uses ASSERT_DEATH? + +GoogleTest does not interleave tests from different test suites. That is, it +runs all tests in one test suite first, and then runs all tests in the next test +suite, and so on. GoogleTest does this because it needs to set up a test suite +before the first test in it is run, and tear it down afterwards. Splitting up +the test case would require multiple set-up and tear-down processes, which is +inefficient and makes the semantics unclean. + +If we were to determine the order of tests based on test name instead of test +case name, then we would have a problem with the following situation: + +```c++ +TEST_F(FooTest, AbcDeathTest) { ... } +TEST_F(FooTest, Uvw) { ... } + +TEST_F(BarTest, DefDeathTest) { ... } +TEST_F(BarTest, Xyz) { ... } +``` + +Since `FooTest.AbcDeathTest` needs to run before `BarTest.Xyz`, and we don't +interleave tests from different test suites, we need to run all tests in the +`FooTest` case before running any test in the `BarTest` case. This contradicts +with the requirement to run `BarTest.DefDeathTest` before `FooTest.Uvw`. + +## But I don't like calling my entire test suite \*DeathTest when it contains both death tests and non-death tests. What do I do? + +You don't have to, but if you like, you may split up the test suite into +`FooTest` and `FooDeathTest`, where the names make it clear that they are +related: + +```c++ +class FooTest : public ::testing::Test { ... }; + +TEST_F(FooTest, Abc) { ... } +TEST_F(FooTest, Def) { ... } + +using FooDeathTest = FooTest; + +TEST_F(FooDeathTest, Uvw) { ... EXPECT_DEATH(...) ... } +TEST_F(FooDeathTest, Xyz) { ... ASSERT_DEATH(...) ... } +``` + +## GoogleTest prints the LOG messages in a death test's child process only when the test fails. How can I see the LOG messages when the death test succeeds? + +Printing the LOG messages generated by the statement inside `EXPECT_DEATH()` +makes it harder to search for real problems in the parent's log. 
Therefore, +GoogleTest only prints them when the death test has failed. + +If you really need to see such LOG messages, a workaround is to temporarily +break the death test (e.g. by changing the regex pattern it is expected to +match). Admittedly, this is a hack. We'll consider a more permanent solution +after the fork-and-exec-style death tests are implemented. + +## The compiler complains about `no match for 'operator<<'` when I use an assertion. What gives? + +If you use a user-defined type `FooType` in an assertion, you must make sure +there is an `std::ostream& operator<<(std::ostream&, const FooType&)` function +defined such that we can print a value of `FooType`. + +In addition, if `FooType` is declared in a name space, the `<<` operator also +needs to be defined in the *same* name space. See +[Tip of the Week #49](http://abseil.io/tips/49) for details. + +## How do I suppress the memory leak messages on Windows? + +Since the statically initialized GoogleTest singleton requires allocations on +the heap, the Visual C++ memory leak detector will report memory leaks at the +end of the program run. The easiest way to avoid this is to use the +`_CrtMemCheckpoint` and `_CrtMemDumpAllObjectsSince` calls to not report any +statically initialized heap objects. See MSDN for more details and additional +heap check/debug routines. + +## How can my code detect if it is running in a test? + +If you write code that sniffs whether it's running in a test and does different +things accordingly, you are leaking test-only logic into production code and +there is no easy way to ensure that the test-only code paths aren't run by +mistake in production. Such cleverness also leads to +[Heisenbugs](https://en.wikipedia.org/wiki/Heisenbug). Therefore we strongly +advise against the practice, and GoogleTest doesn't provide a way to do it. + +In general, the recommended way to cause the code to behave differently under +test is [Dependency Injection](http://en.wikipedia.org/wiki/Dependency_injection). You can inject +different functionality from the test and from the production code. Since your +production code doesn't link in the for-test logic at all (the +[`testonly`](http://docs.bazel.build/versions/master/be/common-definitions.html#common.testonly) attribute for BUILD targets helps to ensure +that), there is no danger in accidentally running it. + +However, if you *really*, *really*, *really* have no choice, and if you follow +the rule of ending your test program names with `_test`, you can use the +*horrible* hack of sniffing your executable name (`argv[0]` in `main()`) to know +whether the code is under test. + +## How do I temporarily disable a test? + +If you have a broken test that you cannot fix right away, you can add the +`DISABLED_` prefix to its name. This will exclude it from execution. This is +better than commenting out the code or using `#if 0`, as disabled tests are +still compiled (and thus won't rot). + +To include disabled tests in test execution, just invoke the test program with +the `--gtest_also_run_disabled_tests` flag. + +## Is it OK if I have two separate `TEST(Foo, Bar)` test methods defined in different namespaces? + +Yes. + +The rule is **all test methods in the same test suite must use the same fixture +class.** This means that the following is **allowed** because both tests use the +same fixture class (`::testing::Test`). 
+ +```c++ +namespace foo { +TEST(CoolTest, DoSomething) { + SUCCEED(); +} +} // namespace foo + +namespace bar { +TEST(CoolTest, DoSomething) { + SUCCEED(); +} +} // namespace bar +``` + +However, the following code is **not allowed** and will produce a runtime error +from GoogleTest because the test methods are using different test fixture +classes with the same test suite name. + +```c++ +namespace foo { +class CoolTest : public ::testing::Test {}; // Fixture foo::CoolTest +TEST_F(CoolTest, DoSomething) { + SUCCEED(); +} +} // namespace foo + +namespace bar { +class CoolTest : public ::testing::Test {}; // Fixture: bar::CoolTest +TEST_F(CoolTest, DoSomething) { + SUCCEED(); +} +} // namespace bar +``` diff --git a/3rdparty/googletest-1.13.0/docs/gmock_cheat_sheet.md b/3rdparty/googletest-1.13.0/docs/gmock_cheat_sheet.md new file mode 100644 index 0000000000000000000000000000000000000000..2fb0403e616a79a46294a6dd5c8489f1f1dd2a78 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/gmock_cheat_sheet.md @@ -0,0 +1,241 @@ +# gMock Cheat Sheet + +## Defining a Mock Class + +### Mocking a Normal Class {#MockClass} + +Given + +```cpp +class Foo { + public: + virtual ~Foo(); + virtual int GetSize() const = 0; + virtual string Describe(const char* name) = 0; + virtual string Describe(int type) = 0; + virtual bool Process(Bar elem, int count) = 0; +}; +``` + +(note that `~Foo()` **must** be virtual) we can define its mock as + +```cpp +#include "gmock/gmock.h" + +class MockFoo : public Foo { + public: + MOCK_METHOD(int, GetSize, (), (const, override)); + MOCK_METHOD(string, Describe, (const char* name), (override)); + MOCK_METHOD(string, Describe, (int type), (override)); + MOCK_METHOD(bool, Process, (Bar elem, int count), (override)); +}; +``` + +To create a "nice" mock, which ignores all uninteresting calls, a "naggy" mock, +which warns on all uninteresting calls, or a "strict" mock, which treats them as +failures: + +```cpp +using ::testing::NiceMock; +using ::testing::NaggyMock; +using ::testing::StrictMock; + +NiceMock nice_foo; // The type is a subclass of MockFoo. +NaggyMock naggy_foo; // The type is a subclass of MockFoo. +StrictMock strict_foo; // The type is a subclass of MockFoo. +``` + +{: .callout .note} +**Note:** A mock object is currently naggy by default. We may make it nice by +default in the future. + +### Mocking a Class Template {#MockTemplate} + +Class templates can be mocked just like any class. + +To mock + +```cpp +template +class StackInterface { + public: + virtual ~StackInterface(); + virtual int GetSize() const = 0; + virtual void Push(const Elem& x) = 0; +}; +``` + +(note that all member functions that are mocked, including `~StackInterface()` +**must** be virtual). + +```cpp +template +class MockStack : public StackInterface { + public: + MOCK_METHOD(int, GetSize, (), (const, override)); + MOCK_METHOD(void, Push, (const Elem& x), (override)); +}; +``` + +### Specifying Calling Conventions for Mock Functions + +If your mock function doesn't use the default calling convention, you can +specify it by adding `Calltype(convention)` to `MOCK_METHOD`'s 4th parameter. +For example, + +```cpp + MOCK_METHOD(bool, Foo, (int n), (Calltype(STDMETHODCALLTYPE))); + MOCK_METHOD(int, Bar, (double x, double y), + (const, Calltype(STDMETHODCALLTYPE))); +``` + +where `STDMETHODCALLTYPE` is defined by `` on Windows. + +## Using Mocks in Tests {#UsingMocks} + +The typical work flow is: + +1. Import the gMock names you need to use. 
All gMock symbols are in the + `testing` namespace unless they are macros or otherwise noted. +2. Create the mock objects. +3. Optionally, set the default actions of the mock objects. +4. Set your expectations on the mock objects (How will they be called? What + will they do?). +5. Exercise code that uses the mock objects; if necessary, check the result + using googletest assertions. +6. When a mock object is destructed, gMock automatically verifies that all + expectations on it have been satisfied. + +Here's an example: + +```cpp +using ::testing::Return; // #1 + +TEST(BarTest, DoesThis) { + MockFoo foo; // #2 + + ON_CALL(foo, GetSize()) // #3 + .WillByDefault(Return(1)); + // ... other default actions ... + + EXPECT_CALL(foo, Describe(5)) // #4 + .Times(3) + .WillRepeatedly(Return("Category 5")); + // ... other expectations ... + + EXPECT_EQ(MyProductionFunction(&foo), "good"); // #5 +} // #6 +``` + +## Setting Default Actions {#OnCall} + +gMock has a **built-in default action** for any function that returns `void`, +`bool`, a numeric value, or a pointer. In C++11, it will additionally returns +the default-constructed value, if one exists for the given type. + +To customize the default action for functions with return type `T`, use +[`DefaultValue`](reference/mocking.md#DefaultValue). For example: + +```cpp + // Sets the default action for return type std::unique_ptr to + // creating a new Buzz every time. + DefaultValue>::SetFactory( + [] { return std::make_unique(AccessLevel::kInternal); }); + + // When this fires, the default action of MakeBuzz() will run, which + // will return a new Buzz object. + EXPECT_CALL(mock_buzzer_, MakeBuzz("hello")).Times(AnyNumber()); + + auto buzz1 = mock_buzzer_.MakeBuzz("hello"); + auto buzz2 = mock_buzzer_.MakeBuzz("hello"); + EXPECT_NE(buzz1, nullptr); + EXPECT_NE(buzz2, nullptr); + EXPECT_NE(buzz1, buzz2); + + // Resets the default action for return type std::unique_ptr, + // to avoid interfere with other tests. + DefaultValue>::Clear(); +``` + +To customize the default action for a particular method of a specific mock +object, use [`ON_CALL`](reference/mocking.md#ON_CALL). `ON_CALL` has a similar +syntax to `EXPECT_CALL`, but it is used for setting default behaviors when you +do not require that the mock method is called. See +[Knowing When to Expect](gmock_cook_book.md#UseOnCall) for a more detailed +discussion. + +## Setting Expectations {#ExpectCall} + +See [`EXPECT_CALL`](reference/mocking.md#EXPECT_CALL) in the Mocking Reference. + +## Matchers {#MatcherList} + +See the [Matchers Reference](reference/matchers.md). + +## Actions {#ActionList} + +See the [Actions Reference](reference/actions.md). + +## Cardinalities {#CardinalityList} + +See the [`Times` clause](reference/mocking.md#EXPECT_CALL.Times) of +`EXPECT_CALL` in the Mocking Reference. + +## Expectation Order + +By default, expectations can be matched in *any* order. If some or all +expectations must be matched in a given order, you can use the +[`After` clause](reference/mocking.md#EXPECT_CALL.After) or +[`InSequence` clause](reference/mocking.md#EXPECT_CALL.InSequence) of +`EXPECT_CALL`, or use an [`InSequence` object](reference/mocking.md#InSequence). + +## Verifying and Resetting a Mock + +gMock will verify the expectations on a mock object when it is destructed, or +you can do it earlier: + +```cpp +using ::testing::Mock; +... +// Verifies and removes the expectations on mock_obj; +// returns true if and only if successful. +Mock::VerifyAndClearExpectations(&mock_obj); +... 
+// Verifies and removes the expectations on mock_obj; +// also removes the default actions set by ON_CALL(); +// returns true if and only if successful. +Mock::VerifyAndClear(&mock_obj); +``` + +Do not set new expectations after verifying and clearing a mock after its use. +Setting expectations after code that exercises the mock has undefined behavior. +See [Using Mocks in Tests](gmock_for_dummies.md#using-mocks-in-tests) for more +information. + +You can also tell gMock that a mock object can be leaked and doesn't need to be +verified: + +```cpp +Mock::AllowLeak(&mock_obj); +``` + +## Mock Classes + +gMock defines a convenient mock class template + +```cpp +class MockFunction { + public: + MOCK_METHOD(R, Call, (A1, ..., An)); +}; +``` + +See this [recipe](gmock_cook_book.md#UsingCheckPoints) for one application of +it. + +## Flags + +| Flag | Description | +| :----------------------------- | :---------------------------------------- | +| `--gmock_catch_leaked_mocks=0` | Don't report leaked mock objects as failures. | +| `--gmock_verbose=LEVEL` | Sets the default verbosity level (`info`, `warning`, or `error`) of Google Mock messages. | diff --git a/3rdparty/googletest-1.13.0/docs/gmock_cook_book.md b/3rdparty/googletest-1.13.0/docs/gmock_cook_book.md new file mode 100644 index 0000000000000000000000000000000000000000..fc7db35b82c769fea0c34cf875bf6529d58cbaab --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/gmock_cook_book.md @@ -0,0 +1,4343 @@ +# gMock Cookbook + +You can find recipes for using gMock here. If you haven't yet, please read +[the dummy guide](gmock_for_dummies.md) first to make sure you understand the +basics. + +{: .callout .note} +**Note:** gMock lives in the `testing` name space. For readability, it is +recommended to write `using ::testing::Foo;` once in your file before using the +name `Foo` defined by gMock. We omit such `using` statements in this section for +brevity, but you should do it in your own code. + +## Creating Mock Classes + +Mock classes are defined as normal classes, using the `MOCK_METHOD` macro to +generate mocked methods. The macro gets 3 or 4 parameters: + +```cpp +class MyMock { + public: + MOCK_METHOD(ReturnType, MethodName, (Args...)); + MOCK_METHOD(ReturnType, MethodName, (Args...), (Specs...)); +}; +``` + +The first 3 parameters are simply the method declaration, split into 3 parts. +The 4th parameter accepts a closed list of qualifiers, which affect the +generated method: + +* **`const`** - Makes the mocked method a `const` method. Required if + overriding a `const` method. +* **`override`** - Marks the method with `override`. Recommended if overriding + a `virtual` method. +* **`noexcept`** - Marks the method with `noexcept`. Required if overriding a + `noexcept` method. +* **`Calltype(...)`** - Sets the call type for the method (e.g. to + `STDMETHODCALLTYPE`), useful in Windows. +* **`ref(...)`** - Marks the method with the reference qualification + specified. Required if overriding a method that has reference + qualifications. Eg `ref(&)` or `ref(&&)`. + +### Dealing with unprotected commas + +Unprotected commas, i.e. commas which are not surrounded by parentheses, prevent +`MOCK_METHOD` from parsing its arguments correctly: + +{: .bad} +```cpp +class MockFoo { + public: + MOCK_METHOD(std::pair, GetPair, ()); // Won't compile! + MOCK_METHOD(bool, CheckMap, (std::map, bool)); // Won't compile! 
+};
+```
+
+Solution 1 - wrap with parentheses:
+
+{: .good}
+```cpp
+class MockFoo {
+ public:
+  MOCK_METHOD((std::pair<bool, int>), GetPair, ());
+  MOCK_METHOD(bool, CheckMap, ((std::map<int, double>), bool));
+};
+```
+
+Note that wrapping a return or argument type with parentheses is, in general,
+invalid C++. `MOCK_METHOD` removes the parentheses.
+
+Solution 2 - define an alias:
+
+{: .good}
+```cpp
+class MockFoo {
+ public:
+  using BoolAndInt = std::pair<bool, int>;
+  MOCK_METHOD(BoolAndInt, GetPair, ());
+  using MapIntDouble = std::map<int, double>;
+  MOCK_METHOD(bool, CheckMap, (MapIntDouble, bool));
+};
+```
+
+### Mocking Private or Protected Methods
+
+You must always put a mock method definition (`MOCK_METHOD`) in a `public:`
+section of the mock class, regardless of whether the method being mocked is
+`public`, `protected`, or `private` in the base class. This allows `ON_CALL`
+and `EXPECT_CALL` to reference the mock function from outside of the mock
+class. (Yes, C++ allows a subclass to change the access level of a virtual
+function in the base class.) Example:
+
+```cpp
+class Foo {
+ public:
+  ...
+  virtual bool Transform(Gadget* g) = 0;
+
+ protected:
+  virtual void Resume();
+
+ private:
+  virtual int GetTimeOut();
+};
+
+class MockFoo : public Foo {
+ public:
+  ...
+  MOCK_METHOD(bool, Transform, (Gadget* g), (override));
+
+  // The following must be in the public section, even though the
+  // methods are protected or private in the base class.
+  MOCK_METHOD(void, Resume, (), (override));
+  MOCK_METHOD(int, GetTimeOut, (), (override));
+};
+```
+
+### Mocking Overloaded Methods
+
+You can mock overloaded functions as usual. No special attention is required:
+
+```cpp
+class Foo {
+  ...
+
+  // Must be virtual as we'll inherit from Foo.
+  virtual ~Foo();
+
+  // Overloaded on the types and/or numbers of arguments.
+  virtual int Add(Element x);
+  virtual int Add(int times, Element x);
+
+  // Overloaded on the const-ness of this object.
+  virtual Bar& GetBar();
+  virtual const Bar& GetBar() const;
+};
+
+class MockFoo : public Foo {
+  ...
+  MOCK_METHOD(int, Add, (Element x), (override));
+  MOCK_METHOD(int, Add, (int times, Element x), (override));
+
+  MOCK_METHOD(Bar&, GetBar, (), (override));
+  MOCK_METHOD(const Bar&, GetBar, (), (const, override));
+};
+```
+
+{: .callout .note}
+**Note:** if you don't mock all versions of the overloaded method, the compiler
+will give you a warning about some methods in the base class being hidden. To
+fix that, use `using` to bring them in scope:
+
+```cpp
+class MockFoo : public Foo {
+  ...
+  using Foo::Add;
+  MOCK_METHOD(int, Add, (Element x), (override));
+  // We don't want to mock int Add(int times, Element x);
+  ...
+};
+```
+
+### Mocking Class Templates
+
+You can mock class templates just like any class.
+
+```cpp
+template <typename Elem>
+class StackInterface {
+  ...
+  // Must be virtual as we'll inherit from StackInterface.
+  virtual ~StackInterface();
+
+  virtual int GetSize() const = 0;
+  virtual void Push(const Elem& x) = 0;
+};
+
+template <typename Elem>
+class MockStack : public StackInterface<Elem> {
+  ...
+  MOCK_METHOD(int, GetSize, (), (override));
+  MOCK_METHOD(void, Push, (const Elem& x), (override));
+};
+```
+
+### Mocking Non-virtual Methods {#MockingNonVirtualMethods}
+
+gMock can mock non-virtual functions to be used in high-performance dependency
+injection.
+
+In this case, instead of sharing a common base class with the real class, your
+mock class will be *unrelated* to the real class, but contain methods with the
+same signatures.
The syntax for mocking non-virtual methods is the *same* as
+mocking virtual methods (just don't add `override`):
+
+```cpp
+// A simple packet stream class. None of its members is virtual.
+class ConcretePacketStream {
+ public:
+  void AppendPacket(Packet* new_packet);
+  const Packet* GetPacket(size_t packet_number) const;
+  size_t NumberOfPackets() const;
+  ...
+};
+
+// A mock packet stream class. It inherits from no other, but defines
+// GetPacket() and NumberOfPackets().
+class MockPacketStream {
+ public:
+  MOCK_METHOD(const Packet*, GetPacket, (size_t packet_number), (const));
+  MOCK_METHOD(size_t, NumberOfPackets, (), (const));
+  ...
+};
+```
+
+Note that the mock class doesn't define `AppendPacket()`, unlike the real
+class. That's fine as long as the test doesn't need to call it.
+
+Next, you need a way to say that you want to use `ConcretePacketStream` in
+production code, and use `MockPacketStream` in tests. Since the functions are
+not virtual and the two classes are unrelated, you must specify your choice at
+*compile time* (as opposed to run time).
+
+One way to do it is to templatize your code that needs to use a packet stream.
+More specifically, you will give your code a template type argument for the
+type of the packet stream. In production, you will instantiate your template
+with `ConcretePacketStream` as the type argument. In tests, you will
+instantiate the same template with `MockPacketStream`. For example, you may
+write:
+
+```cpp
+template <class PacketStream>
+void CreateConnection(PacketStream* stream) { ... }
+
+template <class PacketStream>
+class PacketReader {
+ public:
+  void ReadPackets(PacketStream* stream, size_t packet_num);
+};
+```
+
+Then you can use `CreateConnection<ConcretePacketStream>()` and
+`PacketReader<ConcretePacketStream>` in production code, and use
+`CreateConnection<MockPacketStream>()` and `PacketReader<MockPacketStream>` in
+tests.
+
+```cpp
+  MockPacketStream mock_stream;
+  EXPECT_CALL(mock_stream, ...)...;
+  .. set more expectations on mock_stream ...
+  PacketReader<MockPacketStream> reader(&mock_stream);
+  ... exercise reader ...
+```
+
+### Mocking Free Functions
+
+It is not possible to directly mock a free function (i.e. a C-style function
+or a static method). If you need to, you can rewrite your code to use an
+interface (abstract class).
+
+Instead of calling a free function (say, `OpenFile`) directly, introduce an
+interface for it and have a concrete subclass that calls the free function:
+
+```cpp
+class FileInterface {
+ public:
+  ...
+  virtual bool Open(const char* path, const char* mode) = 0;
+};
+
+class File : public FileInterface {
+ public:
+  ...
+  bool Open(const char* path, const char* mode) override {
+    return OpenFile(path, mode);
+  }
+};
+```
+
+Your code should talk to `FileInterface` to open a file. Now it's easy to mock
+out the function.
+
+This may seem like a lot of hassle, but in practice you often have multiple
+related functions that you can put in the same interface, so the per-function
+syntactic overhead will be much lower.
+
+If you are concerned about the performance overhead incurred by virtual
+functions, and profiling confirms your concern, you can combine this with the
+recipe for [mocking non-virtual methods](#MockingNonVirtualMethods).
+
+### Old-Style `MOCK_METHODn` Macros
+
+Before the generic `MOCK_METHOD` macro
+[was introduced in 2018](https://github.com/google/googletest/commit/c5f08bf91944ce1b19bcf414fa1760e69d20afc2),
+mocks were created using a family of macros collectively called
+`MOCK_METHODn`. These macros are still supported, though migration to the new
+`MOCK_METHOD` is recommended.
+
+The macros in the `MOCK_METHODn` family differ from `MOCK_METHOD`:
+
+* The general structure is `MOCK_METHODn(MethodName, ReturnType(Args))`,
+  instead of `MOCK_METHOD(ReturnType, MethodName, (Args))`.
+* The number `n` must equal the number of arguments.
+* When mocking a const method, one must use `MOCK_CONST_METHODn`.
+* When mocking a class template, the macro name must be suffixed with `_T`.
+* In order to specify the call type, the macro name must be suffixed with
+  `_WITH_CALLTYPE`, and the call type is the first macro argument.
+
+Old macros and their new equivalents:
+
+| Purpose | Old macro | New macro |
+| :------ | :-------- | :-------- |
+| Simple | `MOCK_METHOD1(Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int))` |
+| Const Method | `MOCK_CONST_METHOD1(Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (const))` |
+| Method in a Class Template | `MOCK_METHOD1_T(Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int))` |
+| Const Method in a Class Template | `MOCK_CONST_METHOD1_T(Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (const))` |
+| Method with Call Type | `MOCK_METHOD1_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (Calltype(STDMETHODCALLTYPE)))` |
+| Const Method with Call Type | `MOCK_CONST_METHOD1_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (const, Calltype(STDMETHODCALLTYPE)))` |
+| Method with Call Type in a Class Template | `MOCK_METHOD1_T_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (Calltype(STDMETHODCALLTYPE)))` |
+| Const Method with Call Type in a Class Template | `MOCK_CONST_METHOD1_T_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (const, Calltype(STDMETHODCALLTYPE)))` |
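+
+For example, a mock written with the old macros and its `MOCK_METHOD`
+equivalent might look as follows. This is only a sketch: the `Turtle` base
+class and its methods are hypothetical, and the `override` qualifiers assume
+the base methods are declared `virtual`.
+
+```cpp
+// Old style: the macro name encodes the argument count and const-ness.
+class OldMockTurtle : public Turtle {
+ public:
+  MOCK_METHOD2(GoTo, void(int x, int y));
+  MOCK_CONST_METHOD0(GetX, int());
+};
+
+// New style: one macro, with qualifiers listed in the fourth parameter.
+class NewMockTurtle : public Turtle {
+ public:
+  MOCK_METHOD(void, GoTo, (int x, int y), (override));
+  MOCK_METHOD(int, GetX, (), (const, override));
+};
+```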
+ +### The Nice, the Strict, and the Naggy {#NiceStrictNaggy} + +If a mock method has no `EXPECT_CALL` spec but is called, we say that it's an +"uninteresting call", and the default action (which can be specified using +`ON_CALL()`) of the method will be taken. Currently, an uninteresting call will +also by default cause gMock to print a warning. + +However, sometimes you may want to ignore these uninteresting calls, and +sometimes you may want to treat them as errors. gMock lets you make the decision +on a per-mock-object basis. + +Suppose your test uses a mock class `MockFoo`: + +```cpp +TEST(...) { + MockFoo mock_foo; + EXPECT_CALL(mock_foo, DoThis()); + ... code that uses mock_foo ... +} +``` + +If a method of `mock_foo` other than `DoThis()` is called, you will get a +warning. However, if you rewrite your test to use `NiceMock` instead, +you can suppress the warning: + +```cpp +using ::testing::NiceMock; + +TEST(...) { + NiceMock mock_foo; + EXPECT_CALL(mock_foo, DoThis()); + ... code that uses mock_foo ... +} +``` + +`NiceMock` is a subclass of `MockFoo`, so it can be used wherever +`MockFoo` is accepted. + +It also works if `MockFoo`'s constructor takes some arguments, as +`NiceMock` "inherits" `MockFoo`'s constructors: + +```cpp +using ::testing::NiceMock; + +TEST(...) { + NiceMock mock_foo(5, "hi"); // Calls MockFoo(5, "hi"). + EXPECT_CALL(mock_foo, DoThis()); + ... code that uses mock_foo ... +} +``` + +The usage of `StrictMock` is similar, except that it makes all uninteresting +calls failures: + +```cpp +using ::testing::StrictMock; + +TEST(...) { + StrictMock mock_foo; + EXPECT_CALL(mock_foo, DoThis()); + ... code that uses mock_foo ... + + // The test will fail if a method of mock_foo other than DoThis() + // is called. +} +``` + +{: .callout .note} +NOTE: `NiceMock` and `StrictMock` only affects *uninteresting* calls (calls of +*methods* with no expectations); they do not affect *unexpected* calls (calls of +methods with expectations, but they don't match). See +[Understanding Uninteresting vs Unexpected Calls](#uninteresting-vs-unexpected). + +There are some caveats though (sadly they are side effects of C++'s +limitations): + +1. `NiceMock` and `StrictMock` only work for mock methods + defined using the `MOCK_METHOD` macro **directly** in the `MockFoo` class. + If a mock method is defined in a **base class** of `MockFoo`, the "nice" or + "strict" modifier may not affect it, depending on the compiler. In + particular, nesting `NiceMock` and `StrictMock` (e.g. + `NiceMock >`) is **not** supported. +2. `NiceMock` and `StrictMock` may not work correctly if the + destructor of `MockFoo` is not virtual. We would like to fix this, but it + requires cleaning up existing tests. + +Finally, you should be **very cautious** about when to use naggy or strict +mocks, as they tend to make tests more brittle and harder to maintain. When you +refactor your code without changing its externally visible behavior, ideally you +shouldn't need to update any tests. If your code interacts with a naggy mock, +however, you may start to get spammed with warnings as the result of your +change. Worse, if your code interacts with a strict mock, your tests may start +to fail and you'll be forced to fix them. Our general recommendation is to use +nice mocks (not yet the default) most of the time, use naggy mocks (the current +default) when developing or debugging tests, and use strict mocks only as the +last resort. 
+ +### Simplifying the Interface without Breaking Existing Code {#SimplerInterfaces} + +Sometimes a method has a long list of arguments that is mostly uninteresting. +For example: + +```cpp +class LogSink { + public: + ... + virtual void send(LogSeverity severity, const char* full_filename, + const char* base_filename, int line, + const struct tm* tm_time, + const char* message, size_t message_len) = 0; +}; +``` + +This method's argument list is lengthy and hard to work with (the `message` +argument is not even 0-terminated). If we mock it as is, using the mock will be +awkward. If, however, we try to simplify this interface, we'll need to fix all +clients depending on it, which is often infeasible. + +The trick is to redispatch the method in the mock class: + +```cpp +class ScopedMockLog : public LogSink { + public: + ... + void send(LogSeverity severity, const char* full_filename, + const char* base_filename, int line, const tm* tm_time, + const char* message, size_t message_len) override { + // We are only interested in the log severity, full file name, and + // log message. + Log(severity, full_filename, std::string(message, message_len)); + } + + // Implements the mock method: + // + // void Log(LogSeverity severity, + // const string& file_path, + // const string& message); + MOCK_METHOD(void, Log, + (LogSeverity severity, const string& file_path, + const string& message)); +}; +``` + +By defining a new mock method with a trimmed argument list, we make the mock +class more user-friendly. + +This technique may also be applied to make overloaded methods more amenable to +mocking. For example, when overloads have been used to implement default +arguments: + +```cpp +class MockTurtleFactory : public TurtleFactory { + public: + Turtle* MakeTurtle(int length, int weight) override { ... } + Turtle* MakeTurtle(int length, int weight, int speed) override { ... } + + // the above methods delegate to this one: + MOCK_METHOD(Turtle*, DoMakeTurtle, ()); +}; +``` + +This allows tests that don't care which overload was invoked to avoid specifying +argument matchers: + +```cpp +ON_CALL(factory, DoMakeTurtle) + .WillByDefault(Return(MakeMockTurtle())); +``` + +### Alternative to Mocking Concrete Classes + +Often you may find yourself using classes that don't implement interfaces. In +order to test your code that uses such a class (let's call it `Concrete`), you +may be tempted to make the methods of `Concrete` virtual and then mock it. + +Try not to do that. + +Making a non-virtual function virtual is a big decision. It creates an extension +point where subclasses can tweak your class' behavior. This weakens your control +on the class because now it's harder to maintain the class invariants. You +should make a function virtual only when there is a valid reason for a subclass +to override it. + +Mocking concrete classes directly is problematic as it creates a tight coupling +between the class and the tests - any small change in the class may invalidate +your tests and make test maintenance a pain. + +To avoid such problems, many programmers have been practicing "coding to +interfaces": instead of talking to the `Concrete` class, your code would define +an interface and talk to it. Then you implement that interface as an adaptor on +top of `Concrete`. In tests, you can easily mock that interface to observe how +your code is doing. + +This technique incurs some overhead: + +* You pay the cost of virtual function calls (usually not a problem). +* There is more abstraction for the programmers to learn. 
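+
+As a rough sketch of what this "coding to interfaces" pattern looks like in
+code (the `Concrete` class and its `DoStuff` method are hypothetical names used
+only for illustration):
+
+```cpp
+// The interface your code talks to, tailored to what your code needs.
+class ConcreteInterface {
+ public:
+  virtual ~ConcreteInterface() = default;
+  virtual int DoStuff(int x) = 0;
+};
+
+// A thin adaptor that forwards to the real Concrete class in production.
+class ConcreteAdaptor : public ConcreteInterface {
+ public:
+  int DoStuff(int x) override { return concrete_.DoStuff(x); }
+
+ private:
+  Concrete concrete_;
+};
+
+// In tests, mock the interface instead of touching Concrete at all.
+class MockConcrete : public ConcreteInterface {
+ public:
+  MOCK_METHOD(int, DoStuff, (int x), (override));
+};
+```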
+ +However, it can also bring significant benefits in addition to better +testability: + +* `Concrete`'s API may not fit your problem domain very well, as you may not + be the only client it tries to serve. By designing your own interface, you + have a chance to tailor it to your need - you may add higher-level + functionalities, rename stuff, etc instead of just trimming the class. This + allows you to write your code (user of the interface) in a more natural way, + which means it will be more readable, more maintainable, and you'll be more + productive. +* If `Concrete`'s implementation ever has to change, you don't have to rewrite + everywhere it is used. Instead, you can absorb the change in your + implementation of the interface, and your other code and tests will be + insulated from this change. + +Some people worry that if everyone is practicing this technique, they will end +up writing lots of redundant code. This concern is totally understandable. +However, there are two reasons why it may not be the case: + +* Different projects may need to use `Concrete` in different ways, so the best + interfaces for them will be different. Therefore, each of them will have its + own domain-specific interface on top of `Concrete`, and they will not be the + same code. +* If enough projects want to use the same interface, they can always share it, + just like they have been sharing `Concrete`. You can check in the interface + and the adaptor somewhere near `Concrete` (perhaps in a `contrib` + sub-directory) and let many projects use it. + +You need to weigh the pros and cons carefully for your particular problem, but +I'd like to assure you that the Java community has been practicing this for a +long time and it's a proven effective technique applicable in a wide variety of +situations. :-) + +### Delegating Calls to a Fake {#DelegatingToFake} + +Some times you have a non-trivial fake implementation of an interface. For +example: + +```cpp +class Foo { + public: + virtual ~Foo() {} + virtual char DoThis(int n) = 0; + virtual void DoThat(const char* s, int* p) = 0; +}; + +class FakeFoo : public Foo { + public: + char DoThis(int n) override { + return (n > 0) ? '+' : + (n < 0) ? '-' : '0'; + } + + void DoThat(const char* s, int* p) override { + *p = strlen(s); + } +}; +``` + +Now you want to mock this interface such that you can set expectations on it. +However, you also want to use `FakeFoo` for the default behavior, as duplicating +it in the mock object is, well, a lot of work. + +When you define the mock class using gMock, you can have it delegate its default +action to a fake class you already have, using this pattern: + +```cpp +class MockFoo : public Foo { + public: + // Normal mock method definitions using gMock. + MOCK_METHOD(char, DoThis, (int n), (override)); + MOCK_METHOD(void, DoThat, (const char* s, int* p), (override)); + + // Delegates the default actions of the methods to a FakeFoo object. + // This must be called *before* the custom ON_CALL() statements. + void DelegateToFake() { + ON_CALL(*this, DoThis).WillByDefault([this](int n) { + return fake_.DoThis(n); + }); + ON_CALL(*this, DoThat).WillByDefault([this](const char* s, int* p) { + fake_.DoThat(s, p); + }); + } + + private: + FakeFoo fake_; // Keeps an instance of the fake in the mock. +}; +``` + +With that, you can use `MockFoo` in your tests as usual. 
Just remember that if +you don't explicitly set an action in an `ON_CALL()` or `EXPECT_CALL()`, the +fake will be called upon to do it.: + +```cpp +using ::testing::_; + +TEST(AbcTest, Xyz) { + MockFoo foo; + + foo.DelegateToFake(); // Enables the fake for delegation. + + // Put your ON_CALL(foo, ...)s here, if any. + + // No action specified, meaning to use the default action. + EXPECT_CALL(foo, DoThis(5)); + EXPECT_CALL(foo, DoThat(_, _)); + + int n = 0; + EXPECT_EQ('+', foo.DoThis(5)); // FakeFoo::DoThis() is invoked. + foo.DoThat("Hi", &n); // FakeFoo::DoThat() is invoked. + EXPECT_EQ(2, n); +} +``` + +**Some tips:** + +* If you want, you can still override the default action by providing your own + `ON_CALL()` or using `.WillOnce()` / `.WillRepeatedly()` in `EXPECT_CALL()`. +* In `DelegateToFake()`, you only need to delegate the methods whose fake + implementation you intend to use. + +* The general technique discussed here works for overloaded methods, but + you'll need to tell the compiler which version you mean. To disambiguate a + mock function (the one you specify inside the parentheses of `ON_CALL()`), + use [this technique](#SelectOverload); to disambiguate a fake function (the + one you place inside `Invoke()`), use a `static_cast` to specify the + function's type. For instance, if class `Foo` has methods `char DoThis(int + n)` and `bool DoThis(double x) const`, and you want to invoke the latter, + you need to write `Invoke(&fake_, static_cast(&FakeFoo::DoThis))` instead of `Invoke(&fake_, &FakeFoo::DoThis)` + (The strange-looking thing inside the angled brackets of `static_cast` is + the type of a function pointer to the second `DoThis()` method.). + +* Having to mix a mock and a fake is often a sign of something gone wrong. + Perhaps you haven't got used to the interaction-based way of testing yet. Or + perhaps your interface is taking on too many roles and should be split up. + Therefore, **don't abuse this**. We would only recommend to do it as an + intermediate step when you are refactoring your code. + +Regarding the tip on mixing a mock and a fake, here's an example on why it may +be a bad sign: Suppose you have a class `System` for low-level system +operations. In particular, it does file and I/O operations. And suppose you want +to test how your code uses `System` to do I/O, and you just want the file +operations to work normally. If you mock out the entire `System` class, you'll +have to provide a fake implementation for the file operation part, which +suggests that `System` is taking on too many roles. + +Instead, you can define a `FileOps` interface and an `IOOps` interface and split +`System`'s functionalities into the two. Then you can mock `IOOps` without +mocking `FileOps`. + +### Delegating Calls to a Real Object + +When using testing doubles (mocks, fakes, stubs, and etc), sometimes their +behaviors will differ from those of the real objects. This difference could be +either intentional (as in simulating an error such that you can test the error +handling code) or unintentional. If your mocks have different behaviors than the +real objects by mistake, you could end up with code that passes the tests but +fails in production. + +You can use the *delegating-to-real* technique to ensure that your mock has the +same behavior as the real object while retaining the ability to validate calls. +This technique is very similar to the [delegating-to-fake](#DelegatingToFake) +technique, the difference being that we use a real object instead of a fake. 
+Here's an example: + +```cpp +using ::testing::AtLeast; + +class MockFoo : public Foo { + public: + MockFoo() { + // By default, all calls are delegated to the real object. + ON_CALL(*this, DoThis).WillByDefault([this](int n) { + return real_.DoThis(n); + }); + ON_CALL(*this, DoThat).WillByDefault([this](const char* s, int* p) { + real_.DoThat(s, p); + }); + ... + } + MOCK_METHOD(char, DoThis, ...); + MOCK_METHOD(void, DoThat, ...); + ... + private: + Foo real_; +}; + +... + MockFoo mock; + EXPECT_CALL(mock, DoThis()) + .Times(3); + EXPECT_CALL(mock, DoThat("Hi")) + .Times(AtLeast(1)); + ... use mock in test ... +``` + +With this, gMock will verify that your code made the right calls (with the right +arguments, in the right order, called the right number of times, etc), and a +real object will answer the calls (so the behavior will be the same as in +production). This gives you the best of both worlds. + +### Delegating Calls to a Parent Class + +Ideally, you should code to interfaces, whose methods are all pure virtual. In +reality, sometimes you do need to mock a virtual method that is not pure (i.e, +it already has an implementation). For example: + +```cpp +class Foo { + public: + virtual ~Foo(); + + virtual void Pure(int n) = 0; + virtual int Concrete(const char* str) { ... } +}; + +class MockFoo : public Foo { + public: + // Mocking a pure method. + MOCK_METHOD(void, Pure, (int n), (override)); + // Mocking a concrete method. Foo::Concrete() is shadowed. + MOCK_METHOD(int, Concrete, (const char* str), (override)); +}; +``` + +Sometimes you may want to call `Foo::Concrete()` instead of +`MockFoo::Concrete()`. Perhaps you want to do it as part of a stub action, or +perhaps your test doesn't need to mock `Concrete()` at all (but it would be +oh-so painful to have to define a new mock class whenever you don't need to mock +one of its methods). + +You can call `Foo::Concrete()` inside an action by: + +```cpp +... + EXPECT_CALL(foo, Concrete).WillOnce([&foo](const char* str) { + return foo.Foo::Concrete(str); + }); +``` + +or tell the mock object that you don't want to mock `Concrete()`: + +```cpp +... + ON_CALL(foo, Concrete).WillByDefault([&foo](const char* str) { + return foo.Foo::Concrete(str); + }); +``` + +(Why don't we just write `{ return foo.Concrete(str); }`? If you do that, +`MockFoo::Concrete()` will be called (and cause an infinite recursion) since +`Foo::Concrete()` is virtual. That's just how C++ works.) + +## Using Matchers + +### Matching Argument Values Exactly + +You can specify exactly which arguments a mock method is expecting: + +```cpp +using ::testing::Return; +... + EXPECT_CALL(foo, DoThis(5)) + .WillOnce(Return('a')); + EXPECT_CALL(foo, DoThat("Hello", bar)); +``` + +### Using Simple Matchers + +You can use matchers to match arguments that have a certain property: + +```cpp +using ::testing::NotNull; +using ::testing::Return; +... + EXPECT_CALL(foo, DoThis(Ge(5))) // The argument must be >= 5. + .WillOnce(Return('a')); + EXPECT_CALL(foo, DoThat("Hello", NotNull())); + // The second argument must not be NULL. +``` + +A frequently used matcher is `_`, which matches anything: + +```cpp + EXPECT_CALL(foo, DoThat(_, NotNull())); +``` + +### Combining Matchers {#CombiningMatchers} + +You can build complex matchers from existing ones using `AllOf()`, +`AllOfArray()`, `AnyOf()`, `AnyOfArray()` and `Not()`: + +```cpp +using ::testing::AllOf; +using ::testing::Gt; +using ::testing::HasSubstr; +using ::testing::Ne; +using ::testing::Not; +... 
+ // The argument must be > 5 and != 10. + EXPECT_CALL(foo, DoThis(AllOf(Gt(5), + Ne(10)))); + + // The first argument must not contain sub-string "blah". + EXPECT_CALL(foo, DoThat(Not(HasSubstr("blah")), + NULL)); +``` + +Matchers are function objects, and parametrized matchers can be composed just +like any other function. However because their types can be long and rarely +provide meaningful information, it can be easier to express them with C++14 +generic lambdas to avoid specifying types. For example, + +```cpp +using ::testing::Contains; +using ::testing::Property; + +inline constexpr auto HasFoo = [](const auto& f) { + return Property("foo", &MyClass::foo, Contains(f)); +}; +... + EXPECT_THAT(x, HasFoo("blah")); +``` + +### Casting Matchers {#SafeMatcherCast} + +gMock matchers are statically typed, meaning that the compiler can catch your +mistake if you use a matcher of the wrong type (for example, if you use `Eq(5)` +to match a `string` argument). Good for you! + +Sometimes, however, you know what you're doing and want the compiler to give you +some slack. One example is that you have a matcher for `long` and the argument +you want to match is `int`. While the two types aren't exactly the same, there +is nothing really wrong with using a `Matcher` to match an `int` - after +all, we can first convert the `int` argument to a `long` losslessly before +giving it to the matcher. + +To support this need, gMock gives you the `SafeMatcherCast(m)` function. It +casts a matcher `m` to type `Matcher`. To ensure safety, gMock checks that +(let `U` be the type `m` accepts : + +1. Type `T` can be *implicitly* cast to type `U`; +2. When both `T` and `U` are built-in arithmetic types (`bool`, integers, and + floating-point numbers), the conversion from `T` to `U` is not lossy (in + other words, any value representable by `T` can also be represented by `U`); + and +3. When `U` is a reference, `T` must also be a reference (as the underlying + matcher may be interested in the address of the `U` value). + +The code won't compile if any of these conditions isn't met. + +Here's one example: + +```cpp +using ::testing::SafeMatcherCast; + +// A base class and a child class. +class Base { ... }; +class Derived : public Base { ... }; + +class MockFoo : public Foo { + public: + MOCK_METHOD(void, DoThis, (Derived* derived), (override)); +}; + +... + MockFoo foo; + // m is a Matcher we got from somewhere. + EXPECT_CALL(foo, DoThis(SafeMatcherCast(m))); +``` + +If you find `SafeMatcherCast(m)` too limiting, you can use a similar function +`MatcherCast(m)`. The difference is that `MatcherCast` works as long as you +can `static_cast` type `T` to type `U`. + +`MatcherCast` essentially lets you bypass C++'s type system (`static_cast` isn't +always safe as it could throw away information, for example), so be careful not +to misuse/abuse it. + +### Selecting Between Overloaded Functions {#SelectOverload} + +If you expect an overloaded function to be called, the compiler may need some +help on which overloaded version it is. + +To disambiguate functions overloaded on the const-ness of this object, use the +`Const()` argument wrapper. + +```cpp +using ::testing::ReturnRef; + +class MockFoo : public Foo { + ... + MOCK_METHOD(Bar&, GetBar, (), (override)); + MOCK_METHOD(const Bar&, GetBar, (), (const, override)); +}; + +... + MockFoo foo; + Bar bar1, bar2; + EXPECT_CALL(foo, GetBar()) // The non-const GetBar(). + .WillOnce(ReturnRef(bar1)); + EXPECT_CALL(Const(foo), GetBar()) // The const GetBar(). 
+ .WillOnce(ReturnRef(bar2)); +``` + +(`Const()` is defined by gMock and returns a `const` reference to its argument.) + +To disambiguate overloaded functions with the same number of arguments but +different argument types, you may need to specify the exact type of a matcher, +either by wrapping your matcher in `Matcher()`, or using a matcher whose +type is fixed (`TypedEq`, `An()`, etc): + +```cpp +using ::testing::An; +using ::testing::Matcher; +using ::testing::TypedEq; + +class MockPrinter : public Printer { + public: + MOCK_METHOD(void, Print, (int n), (override)); + MOCK_METHOD(void, Print, (char c), (override)); +}; + +TEST(PrinterTest, Print) { + MockPrinter printer; + + EXPECT_CALL(printer, Print(An())); // void Print(int); + EXPECT_CALL(printer, Print(Matcher(Lt(5)))); // void Print(int); + EXPECT_CALL(printer, Print(TypedEq('a'))); // void Print(char); + + printer.Print(3); + printer.Print(6); + printer.Print('a'); +} +``` + +### Performing Different Actions Based on the Arguments + +When a mock method is called, the *last* matching expectation that's still +active will be selected (think "newer overrides older"). So, you can make a +method do different things depending on its argument values like this: + +```cpp +using ::testing::_; +using ::testing::Lt; +using ::testing::Return; +... + // The default case. + EXPECT_CALL(foo, DoThis(_)) + .WillRepeatedly(Return('b')); + // The more specific case. + EXPECT_CALL(foo, DoThis(Lt(5))) + .WillRepeatedly(Return('a')); +``` + +Now, if `foo.DoThis()` is called with a value less than 5, `'a'` will be +returned; otherwise `'b'` will be returned. + +### Matching Multiple Arguments as a Whole + +Sometimes it's not enough to match the arguments individually. For example, we +may want to say that the first argument must be less than the second argument. +The `With()` clause allows us to match all arguments of a mock function as a +whole. For example, + +```cpp +using ::testing::_; +using ::testing::Ne; +using ::testing::Lt; +... + EXPECT_CALL(foo, InRange(Ne(0), _)) + .With(Lt()); +``` + +says that the first argument of `InRange()` must not be 0, and must be less than +the second argument. + +The expression inside `With()` must be a matcher of type `Matcher>`, where `A1`, ..., `An` are the types of the function arguments. + +You can also write `AllArgs(m)` instead of `m` inside `.With()`. The two forms +are equivalent, but `.With(AllArgs(Lt()))` is more readable than `.With(Lt())`. + +You can use `Args(m)` to match the `n` selected arguments (as a +tuple) against `m`. For example, + +```cpp +using ::testing::_; +using ::testing::AllOf; +using ::testing::Args; +using ::testing::Lt; +... + EXPECT_CALL(foo, Blah) + .With(AllOf(Args<0, 1>(Lt()), Args<1, 2>(Lt()))); +``` + +says that `Blah` will be called with arguments `x`, `y`, and `z` where `x < y < +z`. Note that in this example, it wasn't necessary to specify the positional +matchers. + +As a convenience and example, gMock provides some matchers for 2-tuples, +including the `Lt()` matcher above. See +[Multi-argument Matchers](reference/matchers.md#MultiArgMatchers) for the +complete list. + +Note that if you want to pass the arguments to a predicate of your own (e.g. +`.With(Args<0, 1>(Truly(&MyPredicate)))`), that predicate MUST be written to +take a `std::tuple` as its argument; gMock will pass the `n` selected arguments +as *one* single tuple to the predicate. 
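+
+For instance, a predicate suitable for `Args<0, 1>(Truly(...))` could be
+written as below. This is a sketch only: it assumes the first two arguments of
+the mocked `Blah()` are `int`s, and `MyPredicate` is just an illustrative name.
+
+```cpp
+using ::testing::Args;
+using ::testing::Truly;
+
+// The two selected arguments arrive packed in a single std::tuple.
+bool MyPredicate(const std::tuple<int, int>& args) {
+  return std::get<0>(args) < std::get<1>(args);
+}
+...
+  EXPECT_CALL(foo, Blah)
+      .With(Args<0, 1>(Truly(&MyPredicate)));
+```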
+ +### Using Matchers as Predicates + +Have you noticed that a matcher is just a fancy predicate that also knows how to +describe itself? Many existing algorithms take predicates as arguments (e.g. +those defined in STL's `` header), and it would be a shame if gMock +matchers were not allowed to participate. + +Luckily, you can use a matcher where a unary predicate functor is expected by +wrapping it inside the `Matches()` function. For example, + +```cpp +#include +#include + +using ::testing::Matches; +using ::testing::Ge; + +vector v; +... +// How many elements in v are >= 10? +const int count = count_if(v.begin(), v.end(), Matches(Ge(10))); +``` + +Since you can build complex matchers from simpler ones easily using gMock, this +gives you a way to conveniently construct composite predicates (doing the same +using STL's `` header is just painful). For example, here's a +predicate that's satisfied by any number that is >= 0, <= 100, and != 50: + +```cpp +using testing::AllOf; +using testing::Ge; +using testing::Le; +using testing::Matches; +using testing::Ne; +... +Matches(AllOf(Ge(0), Le(100), Ne(50))) +``` + +### Using Matchers in googletest Assertions + +See [`EXPECT_THAT`](reference/assertions.md#EXPECT_THAT) in the Assertions +Reference. + +### Using Predicates as Matchers + +gMock provides a set of built-in matchers for matching arguments with expected +values—see the [Matchers Reference](reference/matchers.md) for more information. +In case you find the built-in set lacking, you can use an arbitrary unary +predicate function or functor as a matcher - as long as the predicate accepts a +value of the type you want. You do this by wrapping the predicate inside the +`Truly()` function, for example: + +```cpp +using ::testing::Truly; + +int IsEven(int n) { return (n % 2) == 0 ? 1 : 0; } +... + // Bar() must be called with an even number. + EXPECT_CALL(foo, Bar(Truly(IsEven))); +``` + +Note that the predicate function / functor doesn't have to return `bool`. It +works as long as the return value can be used as the condition in the statement +`if (condition) ...`. + +### Matching Arguments that Are Not Copyable + +When you do an `EXPECT_CALL(mock_obj, Foo(bar))`, gMock saves away a copy of +`bar`. When `Foo()` is called later, gMock compares the argument to `Foo()` with +the saved copy of `bar`. This way, you don't need to worry about `bar` being +modified or destroyed after the `EXPECT_CALL()` is executed. The same is true +when you use matchers like `Eq(bar)`, `Le(bar)`, and so on. + +But what if `bar` cannot be copied (i.e. has no copy constructor)? You could +define your own matcher function or callback and use it with `Truly()`, as the +previous couple of recipes have shown. Or, you may be able to get away from it +if you can guarantee that `bar` won't be changed after the `EXPECT_CALL()` is +executed. Just tell gMock that it should save a reference to `bar`, instead of a +copy of it. Here's how: + +```cpp +using ::testing::Eq; +using ::testing::Lt; +... + // Expects that Foo()'s argument == bar. + EXPECT_CALL(mock_obj, Foo(Eq(std::ref(bar)))); + + // Expects that Foo()'s argument < bar. + EXPECT_CALL(mock_obj, Foo(Lt(std::ref(bar)))); +``` + +Remember: if you do this, don't change `bar` after the `EXPECT_CALL()`, or the +result is undefined. + +### Validating a Member of an Object + +Often a mock function takes a reference to object as an argument. When matching +the argument, you may not want to compare the entire object against a fixed +object, as that may be over-specification. 
Instead, you may need to validate a +certain member variable or the result of a certain getter method of the object. +You can do this with `Field()` and `Property()`. More specifically, + +```cpp +Field(&Foo::bar, m) +``` + +is a matcher that matches a `Foo` object whose `bar` member variable satisfies +matcher `m`. + +```cpp +Property(&Foo::baz, m) +``` + +is a matcher that matches a `Foo` object whose `baz()` method returns a value +that satisfies matcher `m`. + +For example: + +| Expression | Description | +| :--------------------------- | :--------------------------------------- | +| `Field(&Foo::number, Ge(3))` | Matches `x` where `x.number >= 3`. | +| `Property(&Foo::name, StartsWith("John "))` | Matches `x` where `x.name()` starts with `"John "`. | + +Note that in `Property(&Foo::baz, ...)`, method `baz()` must take no argument +and be declared as `const`. Don't use `Property()` against member functions that +you do not own, because taking addresses of functions is fragile and generally +not part of the contract of the function. + +`Field()` and `Property()` can also match plain pointers to objects. For +instance, + +```cpp +using ::testing::Field; +using ::testing::Ge; +... +Field(&Foo::number, Ge(3)) +``` + +matches a plain pointer `p` where `p->number >= 3`. If `p` is `NULL`, the match +will always fail regardless of the inner matcher. + +What if you want to validate more than one members at the same time? Remember +that there are [`AllOf()` and `AllOfArray()`](#CombiningMatchers). + +Finally `Field()` and `Property()` provide overloads that take the field or +property names as the first argument to include it in the error message. This +can be useful when creating combined matchers. + +```cpp +using ::testing::AllOf; +using ::testing::Field; +using ::testing::Matcher; +using ::testing::SafeMatcherCast; + +Matcher IsFoo(const Foo& foo) { + return AllOf(Field("some_field", &Foo::some_field, foo.some_field), + Field("other_field", &Foo::other_field, foo.other_field), + Field("last_field", &Foo::last_field, foo.last_field)); +} +``` + +### Validating the Value Pointed to by a Pointer Argument + +C++ functions often take pointers as arguments. You can use matchers like +`IsNull()`, `NotNull()`, and other comparison matchers to match a pointer, but +what if you want to make sure the value *pointed to* by the pointer, instead of +the pointer itself, has a certain property? Well, you can use the `Pointee(m)` +matcher. + +`Pointee(m)` matches a pointer if and only if `m` matches the value the pointer +points to. For example: + +```cpp +using ::testing::Ge; +using ::testing::Pointee; +... + EXPECT_CALL(foo, Bar(Pointee(Ge(3)))); +``` + +expects `foo.Bar()` to be called with a pointer that points to a value greater +than or equal to 3. + +One nice thing about `Pointee()` is that it treats a `NULL` pointer as a match +failure, so you can write `Pointee(m)` instead of + +```cpp +using ::testing::AllOf; +using ::testing::NotNull; +using ::testing::Pointee; +... + AllOf(NotNull(), Pointee(m)) +``` + +without worrying that a `NULL` pointer will crash your test. + +Also, did we tell you that `Pointee()` works with both raw pointers **and** +smart pointers (`std::unique_ptr`, `std::shared_ptr`, etc)? + +What if you have a pointer to pointer? You guessed it - you can use nested +`Pointee()` to probe deeper inside the value. For example, +`Pointee(Pointee(Lt(3)))` matches a pointer that points to a pointer that points +to a number less than 3 (what a mouthful...). 
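+
+As a small illustration (a sketch, not from the original text; it assumes the
+usual `<gmock/gmock.h>` and `<memory>` includes), `Pointee` works the same way
+on smart pointers and can be nested for multiple levels of indirection:
+
+```cpp
+using ::testing::Lt;
+using ::testing::Pointee;
+
+  auto p = std::make_unique<int>(2);
+  EXPECT_THAT(p, Pointee(Lt(3)));            // The pointed-to value is < 3.
+
+  int n = 2;
+  int* q = &n;
+  int** qq = &q;
+  EXPECT_THAT(qq, Pointee(Pointee(Lt(3))));  // Probes through two levels.
+```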
+ +### Defining a Custom Matcher Class {#CustomMatcherClass} + +Most matchers can be simply defined using [the MATCHER* macros](#NewMatchers), +which are terse and flexible, and produce good error messages. However, these +macros are not very explicit about the interfaces they create and are not always +suitable, especially for matchers that will be widely reused. + +For more advanced cases, you may need to define your own matcher class. A custom +matcher allows you to test a specific invariant property of that object. Let's +take a look at how to do so. + +Imagine you have a mock function that takes an object of type `Foo`, which has +an `int bar()` method and an `int baz()` method. You want to constrain that the +argument's `bar()` value plus its `baz()` value is a given number. (This is an +invariant.) Here's how we can write and use a matcher class to do so: + +```cpp +class BarPlusBazEqMatcher { + public: + using is_gtest_matcher = void; + + explicit BarPlusBazEqMatcher(int expected_sum) + : expected_sum_(expected_sum) {} + + bool MatchAndExplain(const Foo& foo, + std::ostream* /* listener */) const { + return (foo.bar() + foo.baz()) == expected_sum_; + } + + void DescribeTo(std::ostream* os) const { + *os << "bar() + baz() equals " << expected_sum_; + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "bar() + baz() does not equal " << expected_sum_; + } + private: + const int expected_sum_; +}; + +::testing::Matcher BarPlusBazEq(int expected_sum) { + return BarPlusBazEqMatcher(expected_sum); +} + +... + Foo foo; + EXPECT_THAT(foo, BarPlusBazEq(5))...; +``` + +### Matching Containers + +Sometimes an STL container (e.g. list, vector, map, ...) is passed to a mock +function and you may want to validate it. Since most STL containers support the +`==` operator, you can write `Eq(expected_container)` or simply +`expected_container` to match a container exactly. + +Sometimes, though, you may want to be more flexible (for example, the first +element must be an exact match, but the second element can be any positive +number, and so on). Also, containers used in tests often have a small number of +elements, and having to define the expected container out-of-line is a bit of a +hassle. + +You can use the `ElementsAre()` or `UnorderedElementsAre()` matcher in such +cases: + +```cpp +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::Gt; +... + MOCK_METHOD(void, Foo, (const vector& numbers), (override)); +... + EXPECT_CALL(mock, Foo(ElementsAre(1, Gt(0), _, 5))); +``` + +The above matcher says that the container must have 4 elements, which must be 1, +greater than 0, anything, and 5 respectively. + +If you instead write: + +```cpp +using ::testing::_; +using ::testing::Gt; +using ::testing::UnorderedElementsAre; +... + MOCK_METHOD(void, Foo, (const vector& numbers), (override)); +... + EXPECT_CALL(mock, Foo(UnorderedElementsAre(1, Gt(0), _, 5))); +``` + +It means that the container must have 4 elements, which (under some permutation) +must be 1, greater than 0, anything, and 5 respectively. + +As an alternative you can place the arguments in a C-style array and use +`ElementsAreArray()` or `UnorderedElementsAreArray()` instead: + +```cpp +using ::testing::ElementsAreArray; +... + // ElementsAreArray accepts an array of element values. + const int expected_vector1[] = {1, 5, 2, 4, ...}; + EXPECT_CALL(mock, Foo(ElementsAreArray(expected_vector1))); + + // Or, an array of element matchers. 
+ Matcher expected_vector2[] = {1, Gt(2), _, 3, ...}; + EXPECT_CALL(mock, Foo(ElementsAreArray(expected_vector2))); +``` + +In case the array needs to be dynamically created (and therefore the array size +cannot be inferred by the compiler), you can give `ElementsAreArray()` an +additional argument to specify the array size: + +```cpp +using ::testing::ElementsAreArray; +... + int* const expected_vector3 = new int[count]; + ... fill expected_vector3 with values ... + EXPECT_CALL(mock, Foo(ElementsAreArray(expected_vector3, count))); +``` + +Use `Pair` when comparing maps or other associative containers. + +{% raw %} + +```cpp +using ::testing::UnorderedElementsAre; +using ::testing::Pair; +... + absl::flat_hash_map m = {{"a", 1}, {"b", 2}, {"c", 3}}; + EXPECT_THAT(m, UnorderedElementsAre( + Pair("a", 1), Pair("b", 2), Pair("c", 3))); +``` + +{% endraw %} + +**Tips:** + +* `ElementsAre*()` can be used to match *any* container that implements the + STL iterator pattern (i.e. it has a `const_iterator` type and supports + `begin()/end()`), not just the ones defined in STL. It will even work with + container types yet to be written - as long as they follows the above + pattern. +* You can use nested `ElementsAre*()` to match nested (multi-dimensional) + containers. +* If the container is passed by pointer instead of by reference, just write + `Pointee(ElementsAre*(...))`. +* The order of elements *matters* for `ElementsAre*()`. If you are using it + with containers whose element order are undefined (such as a + `std::unordered_map`) you should use `UnorderedElementsAre`. + +### Sharing Matchers + +Under the hood, a gMock matcher object consists of a pointer to a ref-counted +implementation object. Copying matchers is allowed and very efficient, as only +the pointer is copied. When the last matcher that references the implementation +object dies, the implementation object will be deleted. + +Therefore, if you have some complex matcher that you want to use again and +again, there is no need to build it every time. Just assign it to a matcher +variable and use that variable repeatedly! For example, + +```cpp +using ::testing::AllOf; +using ::testing::Gt; +using ::testing::Le; +using ::testing::Matcher; +... + Matcher in_range = AllOf(Gt(5), Le(10)); + ... use in_range as a matcher in multiple EXPECT_CALLs ... +``` + +### Matchers must have no side-effects {#PureMatchers} + +{: .callout .warning} +WARNING: gMock does not guarantee when or how many times a matcher will be +invoked. Therefore, all matchers must be *purely functional*: they cannot have +any side effects, and the match result must not depend on anything other than +the matcher's parameters and the value being matched. + +This requirement must be satisfied no matter how a matcher is defined (e.g., if +it is one of the standard matchers, or a custom matcher). In particular, a +matcher can never call a mock function, as that will affect the state of the +mock object and gMock. + +## Setting Expectations + +### Knowing When to Expect {#UseOnCall} + +**`ON_CALL`** is likely the *single most under-utilized construct* in gMock. + +There are basically two constructs for defining the behavior of a mock object: +`ON_CALL` and `EXPECT_CALL`. The difference? `ON_CALL` defines what happens when +a mock method is called, but doesn't imply any expectation on the method +being called. 
`EXPECT_CALL` not only defines the behavior, but also sets an +expectation that the method will be called with the given arguments, for the +given number of times (and *in the given order* when you specify the order +too). + +Since `EXPECT_CALL` does more, isn't it better than `ON_CALL`? Not really. Every +`EXPECT_CALL` adds a constraint on the behavior of the code under test. Having +more constraints than necessary is *baaad* - even worse than not having enough +constraints. + +This may be counter-intuitive. How could tests that verify more be worse than +tests that verify less? Isn't verification the whole point of tests? + +The answer lies in *what* a test should verify. **A good test verifies the +contract of the code.** If a test over-specifies, it doesn't leave enough +freedom to the implementation. As a result, changing the implementation without +breaking the contract (e.g. refactoring and optimization), which should be +perfectly fine to do, can break such tests. Then you have to spend time fixing +them, only to see them broken again the next time the implementation is changed. + +Keep in mind that one doesn't have to verify more than one property in one test. +In fact, **it's a good style to verify only one thing in one test.** If you do +that, a bug will likely break only one or two tests instead of dozens (which +case would you rather debug?). If you are also in the habit of giving tests +descriptive names that tell what they verify, you can often easily guess what's +wrong just from the test log itself. + +So use `ON_CALL` by default, and only use `EXPECT_CALL` when you actually intend +to verify that the call is made. For example, you may have a bunch of `ON_CALL`s +in your test fixture to set the common mock behavior shared by all tests in the +same group, and write (scarcely) different `EXPECT_CALL`s in different `TEST_F`s +to verify different aspects of the code's behavior. Compared with the style +where each `TEST` has many `EXPECT_CALL`s, this leads to tests that are more +resilient to implementational changes (and thus less likely to require +maintenance) and makes the intent of the tests more obvious (so they are easier +to maintain when you do need to maintain them). + +If you are bothered by the "Uninteresting mock function call" message printed +when a mock method without an `EXPECT_CALL` is called, you may use a `NiceMock` +instead to suppress all such messages for the mock object, or suppress the +message for specific methods by adding `EXPECT_CALL(...).Times(AnyNumber())`. DO +NOT suppress it by blindly adding an `EXPECT_CALL(...)`, or you'll have a test +that's a pain to maintain. + +### Ignoring Uninteresting Calls + +If you are not interested in how a mock method is called, just don't say +anything about it. In this case, if the method is ever called, gMock will +perform its default action to allow the test program to continue. If you are not +happy with the default action taken by gMock, you can override it using +`DefaultValue::Set()` (described [here](#DefaultValue)) or `ON_CALL()`. + +Please note that once you expressed interest in a particular mock method (via +`EXPECT_CALL()`), all invocations to it must match some expectation. If this +function is called but the arguments don't match any `EXPECT_CALL()` statement, +it will be an error. + +### Disallowing Unexpected Calls + +If a mock method shouldn't be called at all, explicitly say so: + +```cpp +using ::testing::_; +... 
+ EXPECT_CALL(foo, Bar(_)) + .Times(0); +``` + +If some calls to the method are allowed, but the rest are not, just list all the +expected calls: + +```cpp +using ::testing::AnyNumber; +using ::testing::Gt; +... + EXPECT_CALL(foo, Bar(5)); + EXPECT_CALL(foo, Bar(Gt(10))) + .Times(AnyNumber()); +``` + +A call to `foo.Bar()` that doesn't match any of the `EXPECT_CALL()` statements +will be an error. + +### Understanding Uninteresting vs Unexpected Calls {#uninteresting-vs-unexpected} + +*Uninteresting* calls and *unexpected* calls are different concepts in gMock. +*Very* different. + +A call `x.Y(...)` is **uninteresting** if there's *not even a single* +`EXPECT_CALL(x, Y(...))` set. In other words, the test isn't interested in the +`x.Y()` method at all, as evident in that the test doesn't care to say anything +about it. + +A call `x.Y(...)` is **unexpected** if there are *some* `EXPECT_CALL(x, +Y(...))`s set, but none of them matches the call. Put another way, the test is +interested in the `x.Y()` method (therefore it explicitly sets some +`EXPECT_CALL` to verify how it's called); however, the verification fails as the +test doesn't expect this particular call to happen. + +**An unexpected call is always an error,** as the code under test doesn't behave +the way the test expects it to behave. + +**By default, an uninteresting call is not an error,** as it violates no +constraint specified by the test. (gMock's philosophy is that saying nothing +means there is no constraint.) However, it leads to a warning, as it *might* +indicate a problem (e.g. the test author might have forgotten to specify a +constraint). + +In gMock, `NiceMock` and `StrictMock` can be used to make a mock class "nice" or +"strict". How does this affect uninteresting calls and unexpected calls? + +A **nice mock** suppresses uninteresting call *warnings*. It is less chatty than +the default mock, but otherwise is the same. If a test fails with a default +mock, it will also fail using a nice mock instead. And vice versa. Don't expect +making a mock nice to change the test's result. + +A **strict mock** turns uninteresting call warnings into errors. So making a +mock strict may change the test's result. + +Let's look at an example: + +```cpp +TEST(...) { + NiceMock mock_registry; + EXPECT_CALL(mock_registry, GetDomainOwner("google.com")) + .WillRepeatedly(Return("Larry Page")); + + // Use mock_registry in code under test. + ... &mock_registry ... +} +``` + +The sole `EXPECT_CALL` here says that all calls to `GetDomainOwner()` must have +`"google.com"` as the argument. If `GetDomainOwner("yahoo.com")` is called, it +will be an unexpected call, and thus an error. *Having a nice mock doesn't +change the severity of an unexpected call.* + +So how do we tell gMock that `GetDomainOwner()` can be called with some other +arguments as well? The standard technique is to add a "catch all" `EXPECT_CALL`: + +```cpp + EXPECT_CALL(mock_registry, GetDomainOwner(_)) + .Times(AnyNumber()); // catches all other calls to this method. + EXPECT_CALL(mock_registry, GetDomainOwner("google.com")) + .WillRepeatedly(Return("Larry Page")); +``` + +Remember that `_` is the wildcard matcher that matches anything. With this, if +`GetDomainOwner("google.com")` is called, it will do what the second +`EXPECT_CALL` says; if it is called with a different argument, it will do what +the first `EXPECT_CALL` says. + +Note that the order of the two `EXPECT_CALL`s is important, as a newer +`EXPECT_CALL` takes precedence over an older one. 
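+
+To make the dispatch concrete, here is roughly what happens with the two
+expectations above in place (a sketch; it assumes `GetDomainOwner()` returns a
+`std::string`):
+
+```cpp
+  // Tried against the newer, more specific expectation first: it matches,
+  // so the WillRepeatedly action runs.
+  EXPECT_EQ("Larry Page", mock_registry.GetDomainOwner("google.com"));
+
+  // Doesn't match the "google.com" expectation, so it falls through to the
+  // catch-all EXPECT_CALL and runs the default action.
+  mock_registry.GetDomainOwner("yahoo.com");
+```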
+ +For more on uninteresting calls, nice mocks, and strict mocks, read +["The Nice, the Strict, and the Naggy"](#NiceStrictNaggy). + +### Ignoring Uninteresting Arguments {#ParameterlessExpectations} + +If your test doesn't care about the parameters (it only cares about the number +or order of calls), you can often simply omit the parameter list: + +```cpp + // Expect foo.Bar( ... ) twice with any arguments. + EXPECT_CALL(foo, Bar).Times(2); + + // Delegate to the given method whenever the factory is invoked. + ON_CALL(foo_factory, MakeFoo) + .WillByDefault(&BuildFooForTest); +``` + +This functionality is only available when a method is not overloaded; to prevent +unexpected behavior it is a compilation error to try to set an expectation on a +method where the specific overload is ambiguous. You can work around this by +supplying a [simpler mock interface](#SimplerInterfaces) than the mocked class +provides. + +This pattern is also useful when the arguments are interesting, but match logic +is substantially complex. You can leave the argument list unspecified and use +SaveArg actions to [save the values for later verification](#SaveArgVerify). If +you do that, you can easily differentiate calling the method the wrong number of +times from calling it with the wrong arguments. + +### Expecting Ordered Calls {#OrderedCalls} + +Although an `EXPECT_CALL()` statement defined later takes precedence when gMock +tries to match a function call with an expectation, by default calls don't have +to happen in the order `EXPECT_CALL()` statements are written. For example, if +the arguments match the matchers in the second `EXPECT_CALL()`, but not those in +the first and third, then the second expectation will be used. + +If you would rather have all calls occur in the order of the expectations, put +the `EXPECT_CALL()` statements in a block where you define a variable of type +`InSequence`: + +```cpp +using ::testing::_; +using ::testing::InSequence; + + { + InSequence s; + + EXPECT_CALL(foo, DoThis(5)); + EXPECT_CALL(bar, DoThat(_)) + .Times(2); + EXPECT_CALL(foo, DoThis(6)); + } +``` + +In this example, we expect a call to `foo.DoThis(5)`, followed by two calls to +`bar.DoThat()` where the argument can be anything, which are in turn followed by +a call to `foo.DoThis(6)`. If a call occurred out-of-order, gMock will report an +error. + +### Expecting Partially Ordered Calls {#PartialOrder} + +Sometimes requiring everything to occur in a predetermined order can lead to +brittle tests. For example, we may care about `A` occurring before both `B` and +`C`, but aren't interested in the relative order of `B` and `C`. In this case, +the test should reflect our real intent, instead of being overly constraining. + +gMock allows you to impose an arbitrary DAG (directed acyclic graph) on the +calls. One way to express the DAG is to use the +[`After` clause](reference/mocking.md#EXPECT_CALL.After) of `EXPECT_CALL`. + +Another way is via the `InSequence()` clause (not the same as the `InSequence` +class), which we borrowed from jMock 2. It's less flexible than `After()`, but +more convenient when you have long chains of sequential calls, as it doesn't +require you to come up with different names for the expectations in the chains. +Here's how it works: + +If we view `EXPECT_CALL()` statements as nodes in a graph, and add an edge from +node A to node B wherever A must occur before B, we can get a DAG. We use the +term "sequence" to mean a directed path in this DAG. 
Now, if we decompose the +DAG into sequences, we just need to know which sequences each `EXPECT_CALL()` +belongs to in order to be able to reconstruct the original DAG. + +So, to specify the partial order on the expectations we need to do two things: +first to define some `Sequence` objects, and then for each `EXPECT_CALL()` say +which `Sequence` objects it is part of. + +Expectations in the same sequence must occur in the order they are written. For +example, + +```cpp +using ::testing::Sequence; +... + Sequence s1, s2; + + EXPECT_CALL(foo, A()) + .InSequence(s1, s2); + EXPECT_CALL(bar, B()) + .InSequence(s1); + EXPECT_CALL(bar, C()) + .InSequence(s2); + EXPECT_CALL(foo, D()) + .InSequence(s2); +``` + +specifies the following DAG (where `s1` is `A -> B`, and `s2` is `A -> C -> D`): + +```text + +---> B + | + A ---| + | + +---> C ---> D +``` + +This means that A must occur before B and C, and C must occur before D. There's +no restriction about the order other than these. + +### Controlling When an Expectation Retires + +When a mock method is called, gMock only considers expectations that are still +active. An expectation is active when created, and becomes inactive (aka +*retires*) when a call that has to occur later has occurred. For example, in + +```cpp +using ::testing::_; +using ::testing::Sequence; +... + Sequence s1, s2; + + EXPECT_CALL(log, Log(WARNING, _, "File too large.")) // #1 + .Times(AnyNumber()) + .InSequence(s1, s2); + EXPECT_CALL(log, Log(WARNING, _, "Data set is empty.")) // #2 + .InSequence(s1); + EXPECT_CALL(log, Log(WARNING, _, "User not found.")) // #3 + .InSequence(s2); +``` + +as soon as either #2 or #3 is matched, #1 will retire. If a warning `"File too +large."` is logged after this, it will be an error. + +Note that an expectation doesn't retire automatically when it's saturated. For +example, + +```cpp +using ::testing::_; +... + EXPECT_CALL(log, Log(WARNING, _, _)); // #1 + EXPECT_CALL(log, Log(WARNING, _, "File too large.")); // #2 +``` + +says that there will be exactly one warning with the message `"File too +large."`. If the second warning contains this message too, #2 will match again +and result in an upper-bound-violated error. + +If this is not what you want, you can ask an expectation to retire as soon as it +becomes saturated: + +```cpp +using ::testing::_; +... + EXPECT_CALL(log, Log(WARNING, _, _)); // #1 + EXPECT_CALL(log, Log(WARNING, _, "File too large.")) // #2 + .RetiresOnSaturation(); +``` + +Here #2 can be used only once, so if you have two warnings with the message +`"File too large."`, the first will match #2 and the second will match #1 - +there will be no error. + +## Using Actions + +### Returning References from Mock Methods + +If a mock function's return type is a reference, you need to use `ReturnRef()` +instead of `Return()` to return a result: + +```cpp +using ::testing::ReturnRef; + +class MockFoo : public Foo { + public: + MOCK_METHOD(Bar&, GetBar, (), (override)); +}; +... + MockFoo foo; + Bar bar; + EXPECT_CALL(foo, GetBar()) + .WillOnce(ReturnRef(bar)); +... +``` + +### Returning Live Values from Mock Methods + +The `Return(x)` action saves a copy of `x` when the action is created, and +always returns the same value whenever it's executed. Sometimes you may want to +instead return the *live* value of `x` (i.e. its value at the time when the +action is *executed*.). Use either `ReturnRef()` or `ReturnPointee()` for this +purpose. 
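+
+For the reference case, a minimal sketch (reusing the `MockFoo::GetBar()` mock
+from the previous recipe) showing that `ReturnRef()` already gives you
+live-value semantics:
+
+```cpp
+using ::testing::ReturnRef;
+...
+  MockFoo foo;
+  Bar bar;
+  EXPECT_CALL(foo, GetBar())
+      .WillRepeatedly(ReturnRef(bar));
+  // The mock hands out a reference to `bar`, so any change made to `bar`
+  // after this point is visible to callers of foo.GetBar().
+```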
+ +If the mock function's return type is a reference, you can do it using +`ReturnRef(x)`, as shown in the previous recipe ("Returning References from Mock +Methods"). However, gMock doesn't let you use `ReturnRef()` in a mock function +whose return type is not a reference, as doing that usually indicates a user +error. So, what shall you do? + +Though you may be tempted, DO NOT use `std::ref()`: + +```cpp +using testing::Return; + +class MockFoo : public Foo { + public: + MOCK_METHOD(int, GetValue, (), (override)); +}; +... + int x = 0; + MockFoo foo; + EXPECT_CALL(foo, GetValue()) + .WillRepeatedly(Return(std::ref(x))); // Wrong! + x = 42; + EXPECT_EQ(42, foo.GetValue()); +``` + +Unfortunately, it doesn't work here. The above code will fail with error: + +```text +Value of: foo.GetValue() + Actual: 0 +Expected: 42 +``` + +The reason is that `Return(*value*)` converts `value` to the actual return type +of the mock function at the time when the action is *created*, not when it is +*executed*. (This behavior was chosen for the action to be safe when `value` is +a proxy object that references some temporary objects.) As a result, +`std::ref(x)` is converted to an `int` value (instead of a `const int&`) when +the expectation is set, and `Return(std::ref(x))` will always return 0. + +`ReturnPointee(pointer)` was provided to solve this problem specifically. It +returns the value pointed to by `pointer` at the time the action is *executed*: + +```cpp +using testing::ReturnPointee; +... + int x = 0; + MockFoo foo; + EXPECT_CALL(foo, GetValue()) + .WillRepeatedly(ReturnPointee(&x)); // Note the & here. + x = 42; + EXPECT_EQ(42, foo.GetValue()); // This will succeed now. +``` + +### Combining Actions + +Want to do more than one thing when a function is called? That's fine. `DoAll()` +allows you to do a sequence of actions every time. Only the return value of the +last action in the sequence will be used. + +```cpp +using ::testing::_; +using ::testing::DoAll; + +class MockFoo : public Foo { + public: + MOCK_METHOD(bool, Bar, (int n), (override)); +}; +... + EXPECT_CALL(foo, Bar(_)) + .WillOnce(DoAll(action_1, + action_2, + ... + action_n)); +``` + +### Verifying Complex Arguments {#SaveArgVerify} + +If you want to verify that a method is called with a particular argument but the +match criteria is complex, it can be difficult to distinguish between +cardinality failures (calling the method the wrong number of times) and argument +match failures. Similarly, if you are matching multiple parameters, it may not +be easy to distinguishing which argument failed to match. For example: + +```cpp + // Not ideal: this could fail because of a problem with arg1 or arg2, or maybe + // just the method wasn't called. + EXPECT_CALL(foo, SendValues(_, ElementsAre(1, 4, 4, 7), EqualsProto( ... ))); +``` + +You can instead save the arguments and test them individually: + +```cpp + EXPECT_CALL(foo, SendValues) + .WillOnce(DoAll(SaveArg<1>(&actual_array), SaveArg<2>(&actual_proto))); + ... run the test + EXPECT_THAT(actual_array, ElementsAre(1, 4, 4, 7)); + EXPECT_THAT(actual_proto, EqualsProto( ... )); +``` + +### Mocking Side Effects {#MockingSideEffects} + +Sometimes a method exhibits its effect not via returning a value but via side +effects. For example, it may change some global state or modify an output +argument. To mock side effects, in general you can define your own action by +implementing `::testing::ActionInterface`. 
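+
+Before reaching for `ActionInterface`, note that any callable can serve as an
+action (see "Using Functions/Methods/Functors/Lambdas as Actions" later in this
+document), so a small lambda is often all you need for a simple side effect.
+A minimal sketch, using a hypothetical `EventSink` interface that is not part
+of the examples above:
+
+```cpp
+using ::testing::_;
+
+class MockEventSink : public EventSink {  // EventSink is hypothetical.
+ public:
+  MOCK_METHOD(void, Record, (int value), (override));
+};
+...
+  std::vector<int> recorded;
+  MockEventSink sink;
+  // The lambda is the custom action; its only job is the side effect of
+  // appending each value to `recorded`.
+  EXPECT_CALL(sink, Record(_))
+      .WillRepeatedly([&recorded](int value) { recorded.push_back(value); });
+```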
+ +If all you need to do is to change an output argument, the built-in +`SetArgPointee()` action is convenient: + +```cpp +using ::testing::_; +using ::testing::SetArgPointee; + +class MockMutator : public Mutator { + public: + MOCK_METHOD(void, Mutate, (bool mutate, int* value), (override)); + ... +} +... + MockMutator mutator; + EXPECT_CALL(mutator, Mutate(true, _)) + .WillOnce(SetArgPointee<1>(5)); +``` + +In this example, when `mutator.Mutate()` is called, we will assign 5 to the +`int` variable pointed to by argument #1 (0-based). + +`SetArgPointee()` conveniently makes an internal copy of the value you pass to +it, removing the need to keep the value in scope and alive. The implication +however is that the value must have a copy constructor and assignment operator. + +If the mock method also needs to return a value as well, you can chain +`SetArgPointee()` with `Return()` using `DoAll()`, remembering to put the +`Return()` statement last: + +```cpp +using ::testing::_; +using ::testing::DoAll; +using ::testing::Return; +using ::testing::SetArgPointee; + +class MockMutator : public Mutator { + public: + ... + MOCK_METHOD(bool, MutateInt, (int* value), (override)); +} +... + MockMutator mutator; + EXPECT_CALL(mutator, MutateInt(_)) + .WillOnce(DoAll(SetArgPointee<0>(5), + Return(true))); +``` + +Note, however, that if you use the `ReturnOKWith()` method, it will override the +values provided by `SetArgPointee()` in the response parameters of your function +call. + +If the output argument is an array, use the `SetArrayArgument(first, last)` +action instead. It copies the elements in source range `[first, last)` to the +array pointed to by the `N`-th (0-based) argument: + +```cpp +using ::testing::NotNull; +using ::testing::SetArrayArgument; + +class MockArrayMutator : public ArrayMutator { + public: + MOCK_METHOD(void, Mutate, (int* values, int num_values), (override)); + ... +} +... + MockArrayMutator mutator; + int values[5] = {1, 2, 3, 4, 5}; + EXPECT_CALL(mutator, Mutate(NotNull(), 5)) + .WillOnce(SetArrayArgument<0>(values, values + 5)); +``` + +This also works when the argument is an output iterator: + +```cpp +using ::testing::_; +using ::testing::SetArrayArgument; + +class MockRolodex : public Rolodex { + public: + MOCK_METHOD(void, GetNames, (std::back_insert_iterator>), + (override)); + ... +} +... + MockRolodex rolodex; + vector names = {"George", "John", "Thomas"}; + EXPECT_CALL(rolodex, GetNames(_)) + .WillOnce(SetArrayArgument<0>(names.begin(), names.end())); +``` + +### Changing a Mock Object's Behavior Based on the State + +If you expect a call to change the behavior of a mock object, you can use +`::testing::InSequence` to specify different behaviors before and after the +call: + +```cpp +using ::testing::InSequence; +using ::testing::Return; + +... + { + InSequence seq; + EXPECT_CALL(my_mock, IsDirty()) + .WillRepeatedly(Return(true)); + EXPECT_CALL(my_mock, Flush()); + EXPECT_CALL(my_mock, IsDirty()) + .WillRepeatedly(Return(false)); + } + my_mock.FlushIfDirty(); +``` + +This makes `my_mock.IsDirty()` return `true` before `my_mock.Flush()` is called +and return `false` afterwards. + +If the behavior change is more complex, you can store the effects in a variable +and make a mock method get its return value from that variable: + +```cpp +using ::testing::_; +using ::testing::SaveArg; +using ::testing::Return; + +ACTION_P(ReturnPointee, p) { return *p; } +... 
+ int previous_value = 0; + EXPECT_CALL(my_mock, GetPrevValue) + .WillRepeatedly(ReturnPointee(&previous_value)); + EXPECT_CALL(my_mock, UpdateValue) + .WillRepeatedly(SaveArg<0>(&previous_value)); + my_mock.DoSomethingToUpdateValue(); +``` + +Here `my_mock.GetPrevValue()` will always return the argument of the last +`UpdateValue()` call. + +### Setting the Default Value for a Return Type {#DefaultValue} + +If a mock method's return type is a built-in C++ type or pointer, by default it +will return 0 when invoked. Also, in C++ 11 and above, a mock method whose +return type has a default constructor will return a default-constructed value by +default. You only need to specify an action if this default value doesn't work +for you. + +Sometimes, you may want to change this default value, or you may want to specify +a default value for types gMock doesn't know about. You can do this using the +`::testing::DefaultValue` class template: + +```cpp +using ::testing::DefaultValue; + +class MockFoo : public Foo { + public: + MOCK_METHOD(Bar, CalculateBar, (), (override)); +}; + + +... + Bar default_bar; + // Sets the default return value for type Bar. + DefaultValue::Set(default_bar); + + MockFoo foo; + + // We don't need to specify an action here, as the default + // return value works for us. + EXPECT_CALL(foo, CalculateBar()); + + foo.CalculateBar(); // This should return default_bar. + + // Unsets the default return value. + DefaultValue::Clear(); +``` + +Please note that changing the default value for a type can make your tests hard +to understand. We recommend you to use this feature judiciously. For example, +you may want to make sure the `Set()` and `Clear()` calls are right next to the +code that uses your mock. + +### Setting the Default Actions for a Mock Method + +You've learned how to change the default value of a given type. However, this +may be too coarse for your purpose: perhaps you have two mock methods with the +same return type and you want them to have different behaviors. The `ON_CALL()` +macro allows you to customize your mock's behavior at the method level: + +```cpp +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::Gt; +using ::testing::Return; +... + ON_CALL(foo, Sign(_)) + .WillByDefault(Return(-1)); + ON_CALL(foo, Sign(0)) + .WillByDefault(Return(0)); + ON_CALL(foo, Sign(Gt(0))) + .WillByDefault(Return(1)); + + EXPECT_CALL(foo, Sign(_)) + .Times(AnyNumber()); + + foo.Sign(5); // This should return 1. + foo.Sign(-9); // This should return -1. + foo.Sign(0); // This should return 0. +``` + +As you may have guessed, when there are more than one `ON_CALL()` statements, +the newer ones in the order take precedence over the older ones. In other words, +the **last** one that matches the function arguments will be used. This matching +order allows you to set up the common behavior in a mock object's constructor or +the test fixture's set-up phase and specialize the mock's behavior later. + +Note that both `ON_CALL` and `EXPECT_CALL` have the same "later statements take +precedence" rule, but they don't interact. That is, `EXPECT_CALL`s have their +own precedence order distinct from the `ON_CALL` precedence order. + +### Using Functions/Methods/Functors/Lambdas as Actions {#FunctionsAsActions} + +If the built-in actions don't suit you, you can use an existing callable +(function, `std::function`, method, functor, lambda) as an action. 
+ +```cpp +using ::testing::_; using ::testing::Invoke; + +class MockFoo : public Foo { + public: + MOCK_METHOD(int, Sum, (int x, int y), (override)); + MOCK_METHOD(bool, ComplexJob, (int x), (override)); +}; + +int CalculateSum(int x, int y) { return x + y; } +int Sum3(int x, int y, int z) { return x + y + z; } + +class Helper { + public: + bool ComplexJob(int x); +}; + +... + MockFoo foo; + Helper helper; + EXPECT_CALL(foo, Sum(_, _)) + .WillOnce(&CalculateSum) + .WillRepeatedly(Invoke(NewPermanentCallback(Sum3, 1))); + EXPECT_CALL(foo, ComplexJob(_)) + .WillOnce(Invoke(&helper, &Helper::ComplexJob)) + .WillOnce([] { return true; }) + .WillRepeatedly([](int x) { return x > 0; }); + + foo.Sum(5, 6); // Invokes CalculateSum(5, 6). + foo.Sum(2, 3); // Invokes Sum3(1, 2, 3). + foo.ComplexJob(10); // Invokes helper.ComplexJob(10). + foo.ComplexJob(-1); // Invokes the inline lambda. +``` + +The only requirement is that the type of the function, etc must be *compatible* +with the signature of the mock function, meaning that the latter's arguments (if +it takes any) can be implicitly converted to the corresponding arguments of the +former, and the former's return type can be implicitly converted to that of the +latter. So, you can invoke something whose type is *not* exactly the same as the +mock function, as long as it's safe to do so - nice, huh? + +Note that: + +* The action takes ownership of the callback and will delete it when the + action itself is destructed. +* If the type of a callback is derived from a base callback type `C`, you need + to implicitly cast it to `C` to resolve the overloading, e.g. + + ```cpp + using ::testing::Invoke; + ... + ResultCallback* is_ok = ...; + ... Invoke(is_ok) ...; // This works. + + BlockingClosure* done = new BlockingClosure; + ... Invoke(implicit_cast(done)) ...; // The cast is necessary. + ``` + +### Using Functions with Extra Info as Actions + +The function or functor you call using `Invoke()` must have the same number of +arguments as the mock function you use it for. Sometimes you may have a function +that takes more arguments, and you are willing to pass in the extra arguments +yourself to fill the gap. You can do this in gMock using callbacks with +pre-bound arguments. Here's an example: + +```cpp +using ::testing::Invoke; + +class MockFoo : public Foo { + public: + MOCK_METHOD(char, DoThis, (int n), (override)); +}; + +char SignOfSum(int x, int y) { + const int sum = x + y; + return (sum > 0) ? '+' : (sum < 0) ? '-' : '0'; +} + +TEST_F(FooTest, Test) { + MockFoo foo; + + EXPECT_CALL(foo, DoThis(2)) + .WillOnce(Invoke(NewPermanentCallback(SignOfSum, 5))); + EXPECT_EQ('+', foo.DoThis(2)); // Invokes SignOfSum(5, 2). +} +``` + +### Invoking a Function/Method/Functor/Lambda/Callback Without Arguments + +`Invoke()` passes the mock function's arguments to the function, etc being +invoked such that the callee has the full context of the call to work with. If +the invoked function is not interested in some or all of the arguments, it can +simply ignore them. + +Yet, a common pattern is that a test author wants to invoke a function without +the arguments of the mock function. She could do that using a wrapper function +that throws away the arguments before invoking an underlining nullary function. +Needless to say, this can be tedious and obscures the intent of the test. + +There are two solutions to this problem. First, you can pass any callable of +zero args as an action. 
Alternatively, use `InvokeWithoutArgs()`, which is like +`Invoke()` except that it doesn't pass the mock function's arguments to the +callee. Here's an example of each: + +```cpp +using ::testing::_; +using ::testing::InvokeWithoutArgs; + +class MockFoo : public Foo { + public: + MOCK_METHOD(bool, ComplexJob, (int n), (override)); +}; + +bool Job1() { ... } +bool Job2(int n, char c) { ... } + +... + MockFoo foo; + EXPECT_CALL(foo, ComplexJob(_)) + .WillOnce([] { Job1(); }); + .WillOnce(InvokeWithoutArgs(NewPermanentCallback(Job2, 5, 'a'))); + + foo.ComplexJob(10); // Invokes Job1(). + foo.ComplexJob(20); // Invokes Job2(5, 'a'). +``` + +Note that: + +* The action takes ownership of the callback and will delete it when the + action itself is destructed. +* If the type of a callback is derived from a base callback type `C`, you need + to implicitly cast it to `C` to resolve the overloading, e.g. + + ```cpp + using ::testing::InvokeWithoutArgs; + ... + ResultCallback* is_ok = ...; + ... InvokeWithoutArgs(is_ok) ...; // This works. + + BlockingClosure* done = ...; + ... InvokeWithoutArgs(implicit_cast(done)) ...; + // The cast is necessary. + ``` + +### Invoking an Argument of the Mock Function + +Sometimes a mock function will receive a function pointer, a functor (in other +words, a "callable") as an argument, e.g. + +```cpp +class MockFoo : public Foo { + public: + MOCK_METHOD(bool, DoThis, (int n, (ResultCallback1* callback)), + (override)); +}; +``` + +and you may want to invoke this callable argument: + +```cpp +using ::testing::_; +... + MockFoo foo; + EXPECT_CALL(foo, DoThis(_, _)) + .WillOnce(...); + // Will execute callback->Run(5), where callback is the + // second argument DoThis() receives. +``` + +{: .callout .note} +NOTE: The section below is legacy documentation from before C++ had lambdas: + +Arghh, you need to refer to a mock function argument but C++ has no lambda +(yet), so you have to define your own action. :-( Or do you really? + +Well, gMock has an action to solve *exactly* this problem: + +```cpp +InvokeArgument(arg_1, arg_2, ..., arg_m) +``` + +will invoke the `N`-th (0-based) argument the mock function receives, with +`arg_1`, `arg_2`, ..., and `arg_m`. No matter if the argument is a function +pointer, a functor, or a callback. gMock handles them all. + +With that, you could write: + +```cpp +using ::testing::_; +using ::testing::InvokeArgument; +... + EXPECT_CALL(foo, DoThis(_, _)) + .WillOnce(InvokeArgument<1>(5)); + // Will execute callback->Run(5), where callback is the + // second argument DoThis() receives. +``` + +What if the callable takes an argument by reference? No problem - just wrap it +inside `std::ref()`: + +```cpp + ... + MOCK_METHOD(bool, Bar, + ((ResultCallback2* callback)), + (override)); + ... + using ::testing::_; + using ::testing::InvokeArgument; + ... + MockFoo foo; + Helper helper; + ... + EXPECT_CALL(foo, Bar(_)) + .WillOnce(InvokeArgument<0>(5, std::ref(helper))); + // std::ref(helper) guarantees that a reference to helper, not a copy of + // it, will be passed to the callback. +``` + +What if the callable takes an argument by reference and we do **not** wrap the +argument in `std::ref()`? Then `InvokeArgument()` will *make a copy* of the +argument, and pass a *reference to the copy*, instead of a reference to the +original value, to the callable. This is especially handy when the argument is a +temporary value: + +```cpp + ... + MOCK_METHOD(bool, DoThat, (bool (*f)(const double& x, const string& s)), + (override)); + ... 
+ using ::testing::_; + using ::testing::InvokeArgument; + ... + MockFoo foo; + ... + EXPECT_CALL(foo, DoThat(_)) + .WillOnce(InvokeArgument<0>(5.0, string("Hi"))); + // Will execute (*f)(5.0, string("Hi")), where f is the function pointer + // DoThat() receives. Note that the values 5.0 and string("Hi") are + // temporary and dead once the EXPECT_CALL() statement finishes. Yet + // it's fine to perform this action later, since a copy of the values + // are kept inside the InvokeArgument action. +``` + +### Ignoring an Action's Result + +Sometimes you have an action that returns *something*, but you need an action +that returns `void` (perhaps you want to use it in a mock function that returns +`void`, or perhaps it needs to be used in `DoAll()` and it's not the last in the +list). `IgnoreResult()` lets you do that. For example: + +```cpp +using ::testing::_; +using ::testing::DoAll; +using ::testing::IgnoreResult; +using ::testing::Return; + +int Process(const MyData& data); +string DoSomething(); + +class MockFoo : public Foo { + public: + MOCK_METHOD(void, Abc, (const MyData& data), (override)); + MOCK_METHOD(bool, Xyz, (), (override)); +}; + + ... + MockFoo foo; + EXPECT_CALL(foo, Abc(_)) + // .WillOnce(Invoke(Process)); + // The above line won't compile as Process() returns int but Abc() needs + // to return void. + .WillOnce(IgnoreResult(Process)); + EXPECT_CALL(foo, Xyz()) + .WillOnce(DoAll(IgnoreResult(DoSomething), + // Ignores the string DoSomething() returns. + Return(true))); +``` + +Note that you **cannot** use `IgnoreResult()` on an action that already returns +`void`. Doing so will lead to ugly compiler errors. + +### Selecting an Action's Arguments {#SelectingArgs} + +Say you have a mock function `Foo()` that takes seven arguments, and you have a +custom action that you want to invoke when `Foo()` is called. Trouble is, the +custom action only wants three arguments: + +```cpp +using ::testing::_; +using ::testing::Invoke; +... + MOCK_METHOD(bool, Foo, + (bool visible, const string& name, int x, int y, + (const map>), double& weight, double min_weight, + double max_wight)); +... +bool IsVisibleInQuadrant1(bool visible, int x, int y) { + return visible && x >= 0 && y >= 0; +} +... + EXPECT_CALL(mock, Foo) + .WillOnce(Invoke(IsVisibleInQuadrant1)); // Uh, won't compile. :-( +``` + +To please the compiler God, you need to define an "adaptor" that has the same +signature as `Foo()` and calls the custom action with the right arguments: + +```cpp +using ::testing::_; +using ::testing::Invoke; +... +bool MyIsVisibleInQuadrant1(bool visible, const string& name, int x, int y, + const map, double>& weight, + double min_weight, double max_wight) { + return IsVisibleInQuadrant1(visible, x, y); +} +... + EXPECT_CALL(mock, Foo) + .WillOnce(Invoke(MyIsVisibleInQuadrant1)); // Now it works. +``` + +But isn't this awkward? + +gMock provides a generic *action adaptor*, so you can spend your time minding +more important business than writing your own adaptors. Here's the syntax: + +```cpp +WithArgs(action) +``` + +creates an action that passes the arguments of the mock function at the given +indices (0-based) to the inner `action` and performs it. Using `WithArgs`, our +original example can be written as: + +```cpp +using ::testing::_; +using ::testing::Invoke; +using ::testing::WithArgs; +... + EXPECT_CALL(mock, Foo) + .WillOnce(WithArgs<0, 2, 3>(Invoke(IsVisibleInQuadrant1))); // No need to define your own adaptor. 
+``` + +For better readability, gMock also gives you: + +* `WithoutArgs(action)` when the inner `action` takes *no* argument, and +* `WithArg(action)` (no `s` after `Arg`) when the inner `action` takes + *one* argument. + +As you may have realized, `InvokeWithoutArgs(...)` is just syntactic sugar for +`WithoutArgs(Invoke(...))`. + +Here are more tips: + +* The inner action used in `WithArgs` and friends does not have to be + `Invoke()` -- it can be anything. +* You can repeat an argument in the argument list if necessary, e.g. + `WithArgs<2, 3, 3, 5>(...)`. +* You can change the order of the arguments, e.g. `WithArgs<3, 2, 1>(...)`. +* The types of the selected arguments do *not* have to match the signature of + the inner action exactly. It works as long as they can be implicitly + converted to the corresponding arguments of the inner action. For example, + if the 4-th argument of the mock function is an `int` and `my_action` takes + a `double`, `WithArg<4>(my_action)` will work. + +### Ignoring Arguments in Action Functions + +The [selecting-an-action's-arguments](#SelectingArgs) recipe showed us one way +to make a mock function and an action with incompatible argument lists fit +together. The downside is that wrapping the action in `WithArgs<...>()` can get +tedious for people writing the tests. + +If you are defining a function (or method, functor, lambda, callback) to be used +with `Invoke*()`, and you are not interested in some of its arguments, an +alternative to `WithArgs` is to declare the uninteresting arguments as `Unused`. +This makes the definition less cluttered and less fragile in case the types of +the uninteresting arguments change. It could also increase the chance the action +function can be reused. For example, given + +```cpp + public: + MOCK_METHOD(double, Foo, double(const string& label, double x, double y), + (override)); + MOCK_METHOD(double, Bar, (int index, double x, double y), (override)); +``` + +instead of + +```cpp +using ::testing::_; +using ::testing::Invoke; + +double DistanceToOriginWithLabel(const string& label, double x, double y) { + return sqrt(x*x + y*y); +} +double DistanceToOriginWithIndex(int index, double x, double y) { + return sqrt(x*x + y*y); +} +... + EXPECT_CALL(mock, Foo("abc", _, _)) + .WillOnce(Invoke(DistanceToOriginWithLabel)); + EXPECT_CALL(mock, Bar(5, _, _)) + .WillOnce(Invoke(DistanceToOriginWithIndex)); +``` + +you could write + +```cpp +using ::testing::_; +using ::testing::Invoke; +using ::testing::Unused; + +double DistanceToOrigin(Unused, double x, double y) { + return sqrt(x*x + y*y); +} +... + EXPECT_CALL(mock, Foo("abc", _, _)) + .WillOnce(Invoke(DistanceToOrigin)); + EXPECT_CALL(mock, Bar(5, _, _)) + .WillOnce(Invoke(DistanceToOrigin)); +``` + +### Sharing Actions + +Just like matchers, a gMock action object consists of a pointer to a ref-counted +implementation object. Therefore copying actions is also allowed and very +efficient. When the last action that references the implementation object dies, +the implementation object will be deleted. + +If you have some complex action that you want to use again and again, you may +not have to build it from scratch every time. If the action doesn't have an +internal state (i.e. if it always does the same thing no matter how many times +it has been called), you can assign it to an action variable and use that +variable repeatedly. For example: + +```cpp +using ::testing::Action; +using ::testing::DoAll; +using ::testing::Return; +using ::testing::SetArgPointee; +... 
+ Action set_flag = DoAll(SetArgPointee<0>(5), + Return(true)); + ... use set_flag in .WillOnce() and .WillRepeatedly() ... +``` + +However, if the action has its own state, you may be surprised if you share the +action object. Suppose you have an action factory `IncrementCounter(init)` which +creates an action that increments and returns a counter whose initial value is +`init`, using two actions created from the same expression and using a shared +action will exhibit different behaviors. Example: + +```cpp + EXPECT_CALL(foo, DoThis()) + .WillRepeatedly(IncrementCounter(0)); + EXPECT_CALL(foo, DoThat()) + .WillRepeatedly(IncrementCounter(0)); + foo.DoThis(); // Returns 1. + foo.DoThis(); // Returns 2. + foo.DoThat(); // Returns 1 - Blah() uses a different + // counter than Bar()'s. +``` + +versus + +```cpp +using ::testing::Action; +... + Action increment = IncrementCounter(0); + EXPECT_CALL(foo, DoThis()) + .WillRepeatedly(increment); + EXPECT_CALL(foo, DoThat()) + .WillRepeatedly(increment); + foo.DoThis(); // Returns 1. + foo.DoThis(); // Returns 2. + foo.DoThat(); // Returns 3 - the counter is shared. +``` + +### Testing Asynchronous Behavior + +One oft-encountered problem with gMock is that it can be hard to test +asynchronous behavior. Suppose you had a `EventQueue` class that you wanted to +test, and you created a separate `EventDispatcher` interface so that you could +easily mock it out. However, the implementation of the class fired all the +events on a background thread, which made test timings difficult. You could just +insert `sleep()` statements and hope for the best, but that makes your test +behavior nondeterministic. A better way is to use gMock actions and +`Notification` objects to force your asynchronous test to behave synchronously. + +```cpp +class MockEventDispatcher : public EventDispatcher { + MOCK_METHOD(bool, DispatchEvent, (int32), (override)); +}; + +TEST(EventQueueTest, EnqueueEventTest) { + MockEventDispatcher mock_event_dispatcher; + EventQueue event_queue(&mock_event_dispatcher); + + const int32 kEventId = 321; + absl::Notification done; + EXPECT_CALL(mock_event_dispatcher, DispatchEvent(kEventId)) + .WillOnce([&done] { done.Notify(); }); + + event_queue.EnqueueEvent(kEventId); + done.WaitForNotification(); +} +``` + +In the example above, we set our normal gMock expectations, but then add an +additional action to notify the `Notification` object. Now we can just call +`Notification::WaitForNotification()` in the main thread to wait for the +asynchronous call to finish. After that, our test suite is complete and we can +safely exit. + +{: .callout .note} +Note: this example has a downside: namely, if the expectation is not satisfied, +our test will run forever. It will eventually time-out and fail, but it will +take longer and be slightly harder to debug. To alleviate this problem, you can +use `WaitForNotificationWithTimeout(ms)` instead of `WaitForNotification()`. + +## Misc Recipes on Using gMock + +### Mocking Methods That Use Move-Only Types + +C++11 introduced *move-only types*. A move-only-typed value can be moved from +one object to another, but cannot be copied. `std::unique_ptr` is probably +the most commonly used move-only type. + +Mocking a method that takes and/or returns move-only types presents some +challenges, but nothing insurmountable. This recipe shows you how you can do it. 
+Note that the support for move-only method arguments was only introduced to
+gMock in April 2017; in older code, you may find more complex
+[workarounds](#LegacyMoveOnly) for lack of this feature.
+
+Let’s say we are working on a fictional project that lets one post and share
+snippets called “buzzes”. Your code uses these types:
+
+```cpp
+enum class AccessLevel { kInternal, kPublic };
+
+class Buzz {
+ public:
+  explicit Buzz(AccessLevel access) { ... }
+  ...
+};
+
+class Buzzer {
+ public:
+  virtual ~Buzzer() {}
+  virtual std::unique_ptr<Buzz> MakeBuzz(StringPiece text) = 0;
+  virtual bool ShareBuzz(std::unique_ptr<Buzz> buzz, int64_t timestamp) = 0;
+  ...
+};
+```
+
+A `Buzz` object represents a snippet being posted. A class that implements the
+`Buzzer` interface is capable of creating and sharing `Buzz`es. Methods in
+`Buzzer` may return a `unique_ptr<Buzz>` or take a `unique_ptr<Buzz>`. Now we
+need to mock `Buzzer` in our tests.
+
+To mock a method that accepts or returns move-only types, you just use the
+familiar `MOCK_METHOD` syntax as usual:
+
+```cpp
+class MockBuzzer : public Buzzer {
+ public:
+  MOCK_METHOD(std::unique_ptr<Buzz>, MakeBuzz, (StringPiece text), (override));
+  MOCK_METHOD(bool, ShareBuzz, (std::unique_ptr<Buzz> buzz, int64_t timestamp),
+              (override));
+};
+```
+
+Now that we have the mock class defined, we can use it in tests. In the
+following code examples, we assume that we have defined a `MockBuzzer` object
+named `mock_buzzer_`:
+
+```cpp
+  MockBuzzer mock_buzzer_;
+```
+
+First let’s see how we can set expectations on the `MakeBuzz()` method, which
+returns a `unique_ptr<Buzz>`.
+
+As usual, if you set an expectation without an action (i.e. the `.WillOnce()` or
+`.WillRepeatedly()` clause), when that expectation fires, the default action for
+that method will be taken. Since `unique_ptr<>` has a default constructor that
+returns a null `unique_ptr`, that’s what you’ll get if you don’t specify an
+action:
+
+```cpp
+  // Use the default action.
+  EXPECT_CALL(mock_buzzer_, MakeBuzz("hello"));
+
+  // Triggers the previous EXPECT_CALL.
+  EXPECT_EQ(nullptr, mock_buzzer_.MakeBuzz("hello"));
+```
+
+If you are not happy with the default action, you can tweak it as usual; see
+[Setting Default Actions](#OnCall).
+
+If you just need to return a pre-defined move-only value, you can use the
+`Return(ByMove(...))` action:
+
+```cpp
+  // When this fires, the unique_ptr<> specified by ByMove(...) will
+  // be returned.
+  EXPECT_CALL(mock_buzzer_, MakeBuzz("world"))
+      .WillOnce(Return(ByMove(std::make_unique<Buzz>(AccessLevel::kInternal))));
+
+  EXPECT_NE(nullptr, mock_buzzer_.MakeBuzz("world"));
+```
+
+Note that `ByMove()` is essential here - if you drop it, the code won’t compile.
+
+Quiz time! What do you think will happen if a `Return(ByMove(...))` action is
+performed more than once (e.g. you write `...
+.WillRepeatedly(Return(ByMove(...)));`)? Come think of it, after the first time
+the action runs, the source value will be consumed (since it’s a move-only
+value), so the next time around, there’s no value to move from -- you’ll get a
+run-time error that `Return(ByMove(...))` can only be run once.
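+
+If several calls each need their own pre-made value, a minimal sketch (reusing
+`mock_buzzer_` from above) is to give every call its own
+`WillOnce(Return(ByMove(...)))`, so no value is ever moved from twice:
+
+```cpp
+  EXPECT_CALL(mock_buzzer_, MakeBuzz("hello"))
+      .WillOnce(Return(ByMove(std::make_unique<Buzz>(AccessLevel::kInternal))))
+      .WillOnce(Return(ByMove(std::make_unique<Buzz>(AccessLevel::kPublic))));
+
+  // Each call consumes exactly one of the pre-made Buzz objects.
+  EXPECT_NE(nullptr, mock_buzzer_.MakeBuzz("hello"));
+  EXPECT_NE(nullptr, mock_buzzer_.MakeBuzz("hello"));
+```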
+ +If you need your mock method to do more than just moving a pre-defined value, +remember that you can always use a lambda or a callable object, which can do +pretty much anything you want: + +```cpp + EXPECT_CALL(mock_buzzer_, MakeBuzz("x")) + .WillRepeatedly([](StringPiece text) { + return std::make_unique(AccessLevel::kInternal); + }); + + EXPECT_NE(nullptr, mock_buzzer_.MakeBuzz("x")); + EXPECT_NE(nullptr, mock_buzzer_.MakeBuzz("x")); +``` + +Every time this `EXPECT_CALL` fires, a new `unique_ptr` will be created +and returned. You cannot do this with `Return(ByMove(...))`. + +That covers returning move-only values; but how do we work with methods +accepting move-only arguments? The answer is that they work normally, although +some actions will not compile when any of method's arguments are move-only. You +can always use `Return`, or a [lambda or functor](#FunctionsAsActions): + +```cpp + using ::testing::Unused; + + EXPECT_CALL(mock_buzzer_, ShareBuzz(NotNull(), _)).WillOnce(Return(true)); + EXPECT_TRUE(mock_buzzer_.ShareBuzz(std::make_unique(AccessLevel::kInternal)), + 0); + + EXPECT_CALL(mock_buzzer_, ShareBuzz(_, _)).WillOnce( + [](std::unique_ptr buzz, Unused) { return buzz != nullptr; }); + EXPECT_FALSE(mock_buzzer_.ShareBuzz(nullptr, 0)); +``` + +Many built-in actions (`WithArgs`, `WithoutArgs`,`DeleteArg`, `SaveArg`, ...) +could in principle support move-only arguments, but the support for this is not +implemented yet. If this is blocking you, please file a bug. + +A few actions (e.g. `DoAll`) copy their arguments internally, so they can never +work with non-copyable objects; you'll have to use functors instead. + +#### Legacy workarounds for move-only types {#LegacyMoveOnly} + +Support for move-only function arguments was only introduced to gMock in April +of 2017. In older code, you may encounter the following workaround for the lack +of this feature (it is no longer necessary - we're including it just for +reference): + +```cpp +class MockBuzzer : public Buzzer { + public: + MOCK_METHOD(bool, DoShareBuzz, (Buzz* buzz, Time timestamp)); + bool ShareBuzz(std::unique_ptr buzz, Time timestamp) override { + return DoShareBuzz(buzz.get(), timestamp); + } +}; +``` + +The trick is to delegate the `ShareBuzz()` method to a mock method (let’s call +it `DoShareBuzz()`) that does not take move-only parameters. Then, instead of +setting expectations on `ShareBuzz()`, you set them on the `DoShareBuzz()` mock +method: + +```cpp + MockBuzzer mock_buzzer_; + EXPECT_CALL(mock_buzzer_, DoShareBuzz(NotNull(), _)); + + // When one calls ShareBuzz() on the MockBuzzer like this, the call is + // forwarded to DoShareBuzz(), which is mocked. Therefore this statement + // will trigger the above EXPECT_CALL. + mock_buzzer_.ShareBuzz(std::make_unique(AccessLevel::kInternal), 0); +``` + +### Making the Compilation Faster + +Believe it or not, the *vast majority* of the time spent on compiling a mock +class is in generating its constructor and destructor, as they perform +non-trivial tasks (e.g. verification of the expectations). What's more, mock +methods with different signatures have different types and thus their +constructors/destructors need to be generated by the compiler separately. As a +result, if you mock many different types of methods, compiling your mock class +can get really slow. + +If you are experiencing slow compilation, you can move the definition of your +mock class' constructor and destructor out of the class body and into a `.cc` +file. 
This way, even if you `#include` your mock class in N files, the compiler +only needs to generate its constructor and destructor once, resulting in a much +faster compilation. + +Let's illustrate the idea using an example. Here's the definition of a mock +class before applying this recipe: + +```cpp +// File mock_foo.h. +... +class MockFoo : public Foo { + public: + // Since we don't declare the constructor or the destructor, + // the compiler will generate them in every translation unit + // where this mock class is used. + + MOCK_METHOD(int, DoThis, (), (override)); + MOCK_METHOD(bool, DoThat, (const char* str), (override)); + ... more mock methods ... +}; +``` + +After the change, it would look like: + +```cpp +// File mock_foo.h. +... +class MockFoo : public Foo { + public: + // The constructor and destructor are declared, but not defined, here. + MockFoo(); + virtual ~MockFoo(); + + MOCK_METHOD(int, DoThis, (), (override)); + MOCK_METHOD(bool, DoThat, (const char* str), (override)); + ... more mock methods ... +}; +``` + +and + +```cpp +// File mock_foo.cc. +#include "path/to/mock_foo.h" + +// The definitions may appear trivial, but the functions actually do a +// lot of things through the constructors/destructors of the member +// variables used to implement the mock methods. +MockFoo::MockFoo() {} +MockFoo::~MockFoo() {} +``` + +### Forcing a Verification + +When it's being destroyed, your friendly mock object will automatically verify +that all expectations on it have been satisfied, and will generate googletest +failures if not. This is convenient as it leaves you with one less thing to +worry about. That is, unless you are not sure if your mock object will be +destroyed. + +How could it be that your mock object won't eventually be destroyed? Well, it +might be created on the heap and owned by the code you are testing. Suppose +there's a bug in that code and it doesn't delete the mock object properly - you +could end up with a passing test when there's actually a bug. + +Using a heap checker is a good idea and can alleviate the concern, but its +implementation is not 100% reliable. So, sometimes you do want to *force* gMock +to verify a mock object before it is (hopefully) destructed. You can do this +with `Mock::VerifyAndClearExpectations(&mock_object)`: + +```cpp +TEST(MyServerTest, ProcessesRequest) { + using ::testing::Mock; + + MockFoo* const foo = new MockFoo; + EXPECT_CALL(*foo, ...)...; + // ... other expectations ... + + // server now owns foo. + MyServer server(foo); + server.ProcessRequest(...); + + // In case that server's destructor will forget to delete foo, + // this will verify the expectations anyway. + Mock::VerifyAndClearExpectations(foo); +} // server is destroyed when it goes out of scope here. +``` + +{: .callout .tip} +**Tip:** The `Mock::VerifyAndClearExpectations()` function returns a `bool` to +indicate whether the verification was successful (`true` for yes), so you can +wrap that function call inside a `ASSERT_TRUE()` if there is no point going +further when the verification has failed. + +Do not set new expectations after verifying and clearing a mock after its use. +Setting expectations after code that exercises the mock has undefined behavior. +See [Using Mocks in Tests](gmock_for_dummies.md#using-mocks-in-tests) for more +information. 
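+
+As a concrete form of the tip above, a minimal sketch (reusing `foo` from the
+test above) that stops the test as soon as verification fails:
+
+```cpp
+  // VerifyAndClearExpectations() returns false if any expectation on *foo was
+  // violated; ASSERT_TRUE then aborts this test function so later steps don't
+  // run against a mock in a bad state.
+  ASSERT_TRUE(Mock::VerifyAndClearExpectations(foo));
+```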
+ +### Using Checkpoints {#UsingCheckPoints} + +Sometimes you might want to test a mock object's behavior in phases whose sizes +are each manageable, or you might want to set more detailed expectations about +which API calls invoke which mock functions. + +A technique you can use is to put the expectations in a sequence and insert +calls to a dummy "checkpoint" function at specific places. Then you can verify +that the mock function calls do happen at the right time. For example, if you +are exercising the code: + +```cpp + Foo(1); + Foo(2); + Foo(3); +``` + +and want to verify that `Foo(1)` and `Foo(3)` both invoke `mock.Bar("a")`, but +`Foo(2)` doesn't invoke anything, you can write: + +```cpp +using ::testing::MockFunction; + +TEST(FooTest, InvokesBarCorrectly) { + MyMock mock; + // Class MockFunction has exactly one mock method. It is named + // Call() and has type F. + MockFunction check; + { + InSequence s; + + EXPECT_CALL(mock, Bar("a")); + EXPECT_CALL(check, Call("1")); + EXPECT_CALL(check, Call("2")); + EXPECT_CALL(mock, Bar("a")); + } + Foo(1); + check.Call("1"); + Foo(2); + check.Call("2"); + Foo(3); +} +``` + +The expectation spec says that the first `Bar("a")` call must happen before +checkpoint "1", the second `Bar("a")` call must happen after checkpoint "2", and +nothing should happen between the two checkpoints. The explicit checkpoints make +it clear which `Bar("a")` is called by which call to `Foo()`. + +### Mocking Destructors + +Sometimes you want to make sure a mock object is destructed at the right time, +e.g. after `bar->A()` is called but before `bar->B()` is called. We already know +that you can specify constraints on the [order](#OrderedCalls) of mock function +calls, so all we need to do is to mock the destructor of the mock function. + +This sounds simple, except for one problem: a destructor is a special function +with special syntax and special semantics, and the `MOCK_METHOD` macro doesn't +work for it: + +```cpp +MOCK_METHOD(void, ~MockFoo, ()); // Won't compile! +``` + +The good news is that you can use a simple pattern to achieve the same effect. +First, add a mock function `Die()` to your mock class and call it in the +destructor, like this: + +```cpp +class MockFoo : public Foo { + ... + // Add the following two lines to the mock class. + MOCK_METHOD(void, Die, ()); + ~MockFoo() override { Die(); } +}; +``` + +(If the name `Die()` clashes with an existing symbol, choose another name.) Now, +we have translated the problem of testing when a `MockFoo` object dies to +testing when its `Die()` method is called: + +```cpp + MockFoo* foo = new MockFoo; + MockBar* bar = new MockBar; + ... + { + InSequence s; + + // Expects *foo to die after bar->A() and before bar->B(). + EXPECT_CALL(*bar, A()); + EXPECT_CALL(*foo, Die()); + EXPECT_CALL(*bar, B()); + } +``` + +And that's that. + +### Using gMock and Threads {#UsingThreads} + +In a **unit** test, it's best if you could isolate and test a piece of code in a +single-threaded context. That avoids race conditions and dead locks, and makes +debugging your test much easier. + +Yet most programs are multi-threaded, and sometimes to test something we need to +pound on it from more than one thread. gMock works for this purpose too. + +Remember the steps for using a mock: + +1. Create a mock object `foo`. +2. Set its default actions and expectations using `ON_CALL()` and + `EXPECT_CALL()`. +3. The code under test calls methods of `foo`. +4. Optionally, verify and reset the mock. +5. 
Destroy the mock yourself, or let the code under test destroy it. The + destructor will automatically verify it. + +If you follow the following simple rules, your mocks and threads can live +happily together: + +* Execute your *test code* (as opposed to the code being tested) in *one* + thread. This makes your test easy to follow. +* Obviously, you can do step #1 without locking. +* When doing step #2 and #5, make sure no other thread is accessing `foo`. + Obvious too, huh? +* #3 and #4 can be done either in one thread or in multiple threads - anyway + you want. gMock takes care of the locking, so you don't have to do any - + unless required by your test logic. + +If you violate the rules (for example, if you set expectations on a mock while +another thread is calling its methods), you get undefined behavior. That's not +fun, so don't do it. + +gMock guarantees that the action for a mock function is done in the same thread +that called the mock function. For example, in + +```cpp + EXPECT_CALL(mock, Foo(1)) + .WillOnce(action1); + EXPECT_CALL(mock, Foo(2)) + .WillOnce(action2); +``` + +if `Foo(1)` is called in thread 1 and `Foo(2)` is called in thread 2, gMock will +execute `action1` in thread 1 and `action2` in thread 2. + +gMock does *not* impose a sequence on actions performed in different threads +(doing so may create deadlocks as the actions may need to cooperate). This means +that the execution of `action1` and `action2` in the above example *may* +interleave. If this is a problem, you should add proper synchronization logic to +`action1` and `action2` to make the test thread-safe. + +Also, remember that `DefaultValue` is a global resource that potentially +affects *all* living mock objects in your program. Naturally, you won't want to +mess with it from multiple threads or when there still are mocks in action. + +### Controlling How Much Information gMock Prints + +When gMock sees something that has the potential of being an error (e.g. a mock +function with no expectation is called, a.k.a. an uninteresting call, which is +allowed but perhaps you forgot to explicitly ban the call), it prints some +warning messages, including the arguments of the function, the return value, and +the stack trace. Hopefully this will remind you to take a look and see if there +is indeed a problem. + +Sometimes you are confident that your tests are correct and may not appreciate +such friendly messages. Some other times, you are debugging your tests or +learning about the behavior of the code you are testing, and wish you could +observe every mock call that happens (including argument values, the return +value, and the stack trace). Clearly, one size doesn't fit all. + +You can control how much gMock tells you using the `--gmock_verbose=LEVEL` +command-line flag, where `LEVEL` is a string with three possible values: + +* `info`: gMock will print all informational messages, warnings, and errors + (most verbose). At this setting, gMock will also log any calls to the + `ON_CALL/EXPECT_CALL` macros. It will include a stack trace in + "uninteresting call" warnings. +* `warning`: gMock will print both warnings and errors (less verbose); it will + omit the stack traces in "uninteresting call" warnings. This is the default. +* `error`: gMock will print errors only (least verbose). 
+ +Alternatively, you can adjust the value of that flag from within your tests like +so: + +```cpp + ::testing::FLAGS_gmock_verbose = "error"; +``` + +If you find gMock printing too many stack frames with its informational or +warning messages, remember that you can control their amount with the +`--gtest_stack_trace_depth=max_depth` flag. + +Now, judiciously use the right flag to enable gMock serve you better! + +### Gaining Super Vision into Mock Calls + +You have a test using gMock. It fails: gMock tells you some expectations aren't +satisfied. However, you aren't sure why: Is there a typo somewhere in the +matchers? Did you mess up the order of the `EXPECT_CALL`s? Or is the code under +test doing something wrong? How can you find out the cause? + +Won't it be nice if you have X-ray vision and can actually see the trace of all +`EXPECT_CALL`s and mock method calls as they are made? For each call, would you +like to see its actual argument values and which `EXPECT_CALL` gMock thinks it +matches? If you still need some help to figure out who made these calls, how +about being able to see the complete stack trace at each mock call? + +You can unlock this power by running your test with the `--gmock_verbose=info` +flag. For example, given the test program: + +```cpp +#include "gmock/gmock.h" + +using testing::_; +using testing::HasSubstr; +using testing::Return; + +class MockFoo { + public: + MOCK_METHOD(void, F, (const string& x, const string& y)); +}; + +TEST(Foo, Bar) { + MockFoo mock; + EXPECT_CALL(mock, F(_, _)).WillRepeatedly(Return()); + EXPECT_CALL(mock, F("a", "b")); + EXPECT_CALL(mock, F("c", HasSubstr("d"))); + + mock.F("a", "good"); + mock.F("a", "b"); +} +``` + +if you run it with `--gmock_verbose=info`, you will see this output: + +```shell +[ RUN ] Foo.Bar + +foo_test.cc:14: EXPECT_CALL(mock, F(_, _)) invoked +Stack trace: ... + +foo_test.cc:15: EXPECT_CALL(mock, F("a", "b")) invoked +Stack trace: ... + +foo_test.cc:16: EXPECT_CALL(mock, F("c", HasSubstr("d"))) invoked +Stack trace: ... + +foo_test.cc:14: Mock function call matches EXPECT_CALL(mock, F(_, _))... + Function call: F(@0x7fff7c8dad40"a",@0x7fff7c8dad10"good") +Stack trace: ... + +foo_test.cc:15: Mock function call matches EXPECT_CALL(mock, F("a", "b"))... + Function call: F(@0x7fff7c8dada0"a",@0x7fff7c8dad70"b") +Stack trace: ... + +foo_test.cc:16: Failure +Actual function call count doesn't match EXPECT_CALL(mock, F("c", HasSubstr("d")))... + Expected: to be called once + Actual: never called - unsatisfied and active +[ FAILED ] Foo.Bar +``` + +Suppose the bug is that the `"c"` in the third `EXPECT_CALL` is a typo and +should actually be `"a"`. With the above message, you should see that the actual +`F("a", "good")` call is matched by the first `EXPECT_CALL`, not the third as +you thought. From that it should be obvious that the third `EXPECT_CALL` is +written wrong. Case solved. + +If you are interested in the mock call trace but not the stack traces, you can +combine `--gmock_verbose=info` with `--gtest_stack_trace_depth=0` on the test +command line. + +### Running Tests in Emacs + +If you build and run your tests in Emacs using the `M-x google-compile` command +(as many googletest users do), the source file locations of gMock and googletest +errors will be highlighted. Just press `` on one of them and you'll be +taken to the offending line. Or, you can just type `C-x`` to jump to the next +error. 
+ +To make it even easier, you can add the following lines to your `~/.emacs` file: + +```text +(global-set-key "\M-m" 'google-compile) ; m is for make +(global-set-key [M-down] 'next-error) +(global-set-key [M-up] '(lambda () (interactive) (next-error -1))) +``` + +Then you can type `M-m` to start a build (if you want to run the test as well, +just make sure `foo_test.run` or `runtests` is in the build command you supply +after typing `M-m`), or `M-up`/`M-down` to move back and forth between errors. + +## Extending gMock + +### Writing New Matchers Quickly {#NewMatchers} + +{: .callout .warning} +WARNING: gMock does not guarantee when or how many times a matcher will be +invoked. Therefore, all matchers must be functionally pure. See +[this section](#PureMatchers) for more details. + +The `MATCHER*` family of macros can be used to define custom matchers easily. +The syntax: + +```cpp +MATCHER(name, description_string_expression) { statements; } +``` + +will define a matcher with the given name that executes the statements, which +must return a `bool` to indicate if the match succeeds. Inside the statements, +you can refer to the value being matched by `arg`, and refer to its type by +`arg_type`. + +The *description string* is a `string`-typed expression that documents what the +matcher does, and is used to generate the failure message when the match fails. +It can (and should) reference the special `bool` variable `negation`, and should +evaluate to the description of the matcher when `negation` is `false`, or that +of the matcher's negation when `negation` is `true`. + +For convenience, we allow the description string to be empty (`""`), in which +case gMock will use the sequence of words in the matcher name as the +description. + +For example: + +```cpp +MATCHER(IsDivisibleBy7, "") { return (arg % 7) == 0; } +``` + +allows you to write + +```cpp + // Expects mock_foo.Bar(n) to be called where n is divisible by 7. + EXPECT_CALL(mock_foo, Bar(IsDivisibleBy7())); +``` + +or, + +```cpp + using ::testing::Not; + ... + // Verifies that a value is divisible by 7 and the other is not. + EXPECT_THAT(some_expression, IsDivisibleBy7()); + EXPECT_THAT(some_other_expression, Not(IsDivisibleBy7())); +``` + +If the above assertions fail, they will print something like: + +```shell + Value of: some_expression + Expected: is divisible by 7 + Actual: 27 + ... + Value of: some_other_expression + Expected: not (is divisible by 7) + Actual: 21 +``` + +where the descriptions `"is divisible by 7"` and `"not (is divisible by 7)"` are +automatically calculated from the matcher name `IsDivisibleBy7`. + +As you may have noticed, the auto-generated descriptions (especially those for +the negation) may not be so great. You can always override them with a `string` +expression of your own: + +```cpp +MATCHER(IsDivisibleBy7, + absl::StrCat(negation ? "isn't" : "is", " divisible by 7")) { + return (arg % 7) == 0; +} +``` + +Optionally, you can stream additional information to a hidden argument named +`result_listener` to explain the match result. 
For example, a better definition +of `IsDivisibleBy7` is: + +```cpp +MATCHER(IsDivisibleBy7, "") { + if ((arg % 7) == 0) + return true; + + *result_listener << "the remainder is " << (arg % 7); + return false; +} +``` + +With this definition, the above assertion will give a better message: + +```shell + Value of: some_expression + Expected: is divisible by 7 + Actual: 27 (the remainder is 6) +``` + +You should let `MatchAndExplain()` print *any additional information* that can +help a user understand the match result. Note that it should explain why the +match succeeds in case of a success (unless it's obvious) - this is useful when +the matcher is used inside `Not()`. There is no need to print the argument value +itself, as gMock already prints it for you. + +{: .callout .note} +NOTE: The type of the value being matched (`arg_type`) is determined by the +context in which you use the matcher and is supplied to you by the compiler, so +you don't need to worry about declaring it (nor can you). This allows the +matcher to be polymorphic. For example, `IsDivisibleBy7()` can be used to match +any type where the value of `(arg % 7) == 0` can be implicitly converted to a +`bool`. In the `Bar(IsDivisibleBy7())` example above, if method `Bar()` takes an +`int`, `arg_type` will be `int`; if it takes an `unsigned long`, `arg_type` will +be `unsigned long`; and so on. + +### Writing New Parameterized Matchers Quickly + +Sometimes you'll want to define a matcher that has parameters. For that you can +use the macro: + +```cpp +MATCHER_P(name, param_name, description_string) { statements; } +``` + +where the description string can be either `""` or a `string` expression that +references `negation` and `param_name`. + +For example: + +```cpp +MATCHER_P(HasAbsoluteValue, value, "") { return abs(arg) == value; } +``` + +will allow you to write: + +```cpp + EXPECT_THAT(Blah("a"), HasAbsoluteValue(n)); +``` + +which may lead to this message (assuming `n` is 10): + +```shell + Value of: Blah("a") + Expected: has absolute value 10 + Actual: -9 +``` + +Note that both the matcher description and its parameter are printed, making the +message human-friendly. + +In the matcher definition body, you can write `foo_type` to reference the type +of a parameter named `foo`. For example, in the body of +`MATCHER_P(HasAbsoluteValue, value)` above, you can write `value_type` to refer +to the type of `value`. + +gMock also provides `MATCHER_P2`, `MATCHER_P3`, ..., up to `MATCHER_P10` to +support multi-parameter matchers: + +```cpp +MATCHER_Pk(name, param_1, ..., param_k, description_string) { statements; } +``` + +Please note that the custom description string is for a particular *instance* of +the matcher, where the parameters have been bound to actual values. Therefore +usually you'll want the parameter values to be part of the description. gMock +lets you do that by referencing the matcher parameters in the description string +expression. + +For example, + +```cpp +using ::testing::PrintToString; +MATCHER_P2(InClosedRange, low, hi, + absl::StrFormat("%s in range [%s, %s]", negation ? "isn't" : "is", + PrintToString(low), PrintToString(hi))) { + return low <= arg && arg <= hi; +} +... +EXPECT_THAT(3, InClosedRange(4, 6)); +``` + +would generate a failure that contains the message: + +```shell + Expected: is in range [4, 6] +``` + +If you specify `""` as the description, the failure message will contain the +sequence of words in the matcher name followed by the parameter values printed +as a tuple. 
For example,
+
+```cpp
+  MATCHER_P2(InClosedRange, low, hi, "") { ... }
+  ...
+  EXPECT_THAT(3, InClosedRange(4, 6));
+```
+
+would generate a failure that contains the text:
+
+```shell
+  Expected: in closed range (4, 6)
+```
+
+For the purpose of typing, you can view
+
+```cpp
+MATCHER_Pk(Foo, p1, ..., pk, description_string) { ... }
+```
+
+as shorthand for
+
+```cpp
+template <typename p1_type, ..., typename pk_type>
+FooMatcherPk<p1_type, ..., pk_type>
+Foo(p1_type p1, ..., pk_type pk) { ... }
+```
+
+When you write `Foo(v1, ..., vk)`, the compiler infers the types of the
+parameters `v1`, ..., and `vk` for you. If you are not happy with the result of
+the type inference, you can specify the types by explicitly instantiating the
+template, as in `Foo<long, bool>(5, false)`. As said earlier, you don't get to
+(or need to) specify `arg_type` as that's determined by the context in which the
+matcher is used.
+
+You can assign the result of expression `Foo(p1, ..., pk)` to a variable of type
+`FooMatcherPk<p1_type, ..., pk_type>`. This can be useful when composing
+matchers. Matchers that don't have a parameter or have only one parameter have
+special types: you can assign `Foo()` to a `FooMatcher`-typed variable, and
+assign `Foo(p)` to a `FooMatcherP<p_type>`-typed variable.
+
+While you can instantiate a matcher template with reference types, passing the
+parameters by pointer usually makes your code more readable. If, however, you
+still want to pass a parameter by reference, be aware that in the failure
+message generated by the matcher you will see the value of the referenced object
+but not its address.
+
+You can overload matchers with different numbers of parameters:
+
+```cpp
+MATCHER_P(Blah, a, description_string_1) { ... }
+MATCHER_P2(Blah, a, b, description_string_2) { ... }
+```
+
+While it's tempting to always use the `MATCHER*` macros when defining a new
+matcher, you should also consider implementing the matcher interface directly
+instead (see the recipes that follow), especially if you need to use the matcher
+a lot. While these approaches require more work, they give you more control over
+the types of the value being matched and the matcher parameters, which in
+general leads to better compiler error messages that pay off in the long run.
+They also allow overloading matchers based on parameter types (as opposed to
+just based on the number of parameters).
+
+### Writing New Monomorphic Matchers
+
+A matcher of argument type `T` implements the matcher interface for `T` and does
+two things: it tests whether a value of type `T` matches the matcher, and can
+describe what kind of values it matches. The latter ability is used for
+generating readable error messages when expectations are violated.
+
+A matcher of `T` must declare a typedef like:
+
+```cpp
+using is_gtest_matcher = void;
+```
+
+and supports the following operations:
+
+```cpp
+// Match a value and optionally explain into an ostream.
+bool matched = matcher.MatchAndExplain(value, maybe_os);
+// where `value` is of type `T` and
+// `maybe_os` is of type `std::ostream*`, where it can be null if the caller
+// is not interested in the textual explanation.
+
+matcher.DescribeTo(os);
+matcher.DescribeNegationTo(os);
+// where `os` is of type `std::ostream*`.
+``` + +If you need a custom matcher but `Truly()` is not a good option (for example, +you may not be happy with the way `Truly(predicate)` describes itself, or you +may want your matcher to be polymorphic as `Eq(value)` is), you can define a +matcher to do whatever you want in two steps: first implement the matcher +interface, and then define a factory function to create a matcher instance. The +second step is not strictly needed but it makes the syntax of using the matcher +nicer. + +For example, you can define a matcher to test whether an `int` is divisible by 7 +and then use it like this: + +```cpp +using ::testing::Matcher; + +class DivisibleBy7Matcher { + public: + using is_gtest_matcher = void; + + bool MatchAndExplain(int n, std::ostream*) const { + return (n % 7) == 0; + } + + void DescribeTo(std::ostream* os) const { + *os << "is divisible by 7"; + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "is not divisible by 7"; + } +}; + +Matcher DivisibleBy7() { + return DivisibleBy7Matcher(); +} + +... + EXPECT_CALL(foo, Bar(DivisibleBy7())); +``` + +You may improve the matcher message by streaming additional information to the +`os` argument in `MatchAndExplain()`: + +```cpp +class DivisibleBy7Matcher { + public: + bool MatchAndExplain(int n, std::ostream* os) const { + const int remainder = n % 7; + if (remainder != 0 && os != nullptr) { + *os << "the remainder is " << remainder; + } + return remainder == 0; + } + ... +}; +``` + +Then, `EXPECT_THAT(x, DivisibleBy7());` may generate a message like this: + +```shell +Value of: x +Expected: is divisible by 7 + Actual: 23 (the remainder is 2) +``` + +{: .callout .tip} +Tip: for convenience, `MatchAndExplain()` can take a `MatchResultListener*` +instead of `std::ostream*`. + +### Writing New Polymorphic Matchers + +Expanding what we learned above to *polymorphic* matchers is now just as simple +as adding templates in the right place. + +```cpp + +class NotNullMatcher { + public: + using is_gtest_matcher = void; + + // To implement a polymorphic matcher, we just need to make MatchAndExplain a + // template on its first argument. + + // In this example, we want to use NotNull() with any pointer, so + // MatchAndExplain() accepts a pointer of any type as its first argument. + // In general, you can define MatchAndExplain() as an ordinary method or + // a method template, or even overload it. + template + bool MatchAndExplain(T* p, std::ostream*) const { + return p != nullptr; + } + + // Describes the property of a value matching this matcher. + void DescribeTo(std::ostream* os) const { *os << "is not NULL"; } + + // Describes the property of a value NOT matching this matcher. + void DescribeNegationTo(std::ostream* os) const { *os << "is NULL"; } +}; + +NotNullMatcher NotNull() { + return NotNullMatcher(); +} + +... + + EXPECT_CALL(foo, Bar(NotNull())); // The argument must be a non-NULL pointer. +``` + +### Legacy Matcher Implementation + +Defining matchers used to be somewhat more complicated, in which it required +several supporting classes and virtual functions. To implement a matcher for +type `T` using the legacy API you have to derive from `MatcherInterface` and +call `MakeMatcher` to construct the object. + +The interface looks like this: + +```cpp +class MatchResultListener { + public: + ... + // Streams x to the underlying ostream; does nothing if the ostream + // is NULL. + template + MatchResultListener& operator<<(const T& x); + + // Returns the underlying ostream. 
+ std::ostream* stream(); +}; + +template +class MatcherInterface { + public: + virtual ~MatcherInterface(); + + // Returns true if and only if the matcher matches x; also explains the match + // result to 'listener'. + virtual bool MatchAndExplain(T x, MatchResultListener* listener) const = 0; + + // Describes this matcher to an ostream. + virtual void DescribeTo(std::ostream* os) const = 0; + + // Describes the negation of this matcher to an ostream. + virtual void DescribeNegationTo(std::ostream* os) const; +}; +``` + +Fortunately, most of the time you can define a polymorphic matcher easily with +the help of `MakePolymorphicMatcher()`. Here's how you can define `NotNull()` as +an example: + +```cpp +using ::testing::MakePolymorphicMatcher; +using ::testing::MatchResultListener; +using ::testing::PolymorphicMatcher; + +class NotNullMatcher { + public: + // To implement a polymorphic matcher, first define a COPYABLE class + // that has three members MatchAndExplain(), DescribeTo(), and + // DescribeNegationTo(), like the following. + + // In this example, we want to use NotNull() with any pointer, so + // MatchAndExplain() accepts a pointer of any type as its first argument. + // In general, you can define MatchAndExplain() as an ordinary method or + // a method template, or even overload it. + template + bool MatchAndExplain(T* p, + MatchResultListener* /* listener */) const { + return p != NULL; + } + + // Describes the property of a value matching this matcher. + void DescribeTo(std::ostream* os) const { *os << "is not NULL"; } + + // Describes the property of a value NOT matching this matcher. + void DescribeNegationTo(std::ostream* os) const { *os << "is NULL"; } +}; + +// To construct a polymorphic matcher, pass an instance of the class +// to MakePolymorphicMatcher(). Note the return type. +PolymorphicMatcher NotNull() { + return MakePolymorphicMatcher(NotNullMatcher()); +} + +... + + EXPECT_CALL(foo, Bar(NotNull())); // The argument must be a non-NULL pointer. +``` + +{: .callout .note} +**Note:** Your polymorphic matcher class does **not** need to inherit from +`MatcherInterface` or any other class, and its methods do **not** need to be +virtual. + +Like in a monomorphic matcher, you may explain the match result by streaming +additional information to the `listener` argument in `MatchAndExplain()`. + +### Writing New Cardinalities + +A cardinality is used in `Times()` to tell gMock how many times you expect a +call to occur. It doesn't have to be exact. For example, you can say +`AtLeast(5)` or `Between(2, 4)`. + +If the [built-in set](gmock_cheat_sheet.md#CardinalityList) of cardinalities +doesn't suit you, you are free to define your own by implementing the following +interface (in namespace `testing`): + +```cpp +class CardinalityInterface { + public: + virtual ~CardinalityInterface(); + + // Returns true if and only if call_count calls will satisfy this cardinality. + virtual bool IsSatisfiedByCallCount(int call_count) const = 0; + + // Returns true if and only if call_count calls will saturate this + // cardinality. + virtual bool IsSaturatedByCallCount(int call_count) const = 0; + + // Describes self to an ostream. 
+  virtual void DescribeTo(std::ostream* os) const = 0;
+};
+```
+
+For example, to specify that a call must occur an even number of times, you can
+write
+
+```cpp
+using ::testing::Cardinality;
+using ::testing::CardinalityInterface;
+using ::testing::MakeCardinality;
+
+class EvenNumberCardinality : public CardinalityInterface {
+ public:
+  bool IsSatisfiedByCallCount(int call_count) const override {
+    return (call_count % 2) == 0;
+  }
+
+  bool IsSaturatedByCallCount(int call_count) const override {
+    return false;
+  }
+
+  void DescribeTo(std::ostream* os) const override {
+    *os << "called even number of times";
+  }
+};
+
+Cardinality EvenNumber() {
+  return MakeCardinality(new EvenNumberCardinality);
+}
+
+...
+  EXPECT_CALL(foo, Bar(3))
+      .Times(EvenNumber());
+```
+
+### Writing New Actions {#QuickNewActions}
+
+If the built-in actions don't work for you, you can easily define your own.
+All you need is a call operator with a signature compatible with the mocked
+function. So you can use a lambda:
+
+```cpp
+MockFunction<int(int)> mock;
+EXPECT_CALL(mock, Call).WillOnce([](const int input) { return input * 7; });
+EXPECT_EQ(14, mock.AsStdFunction()(2));
+```
+
+Or a struct with a call operator (even a templated one):
+
+```cpp
+struct MultiplyBy {
+  template <typename T>
+  T operator()(T arg) { return arg * multiplier; }
+
+  int multiplier;
+};
+
+// Then use:
+//   EXPECT_CALL(...).WillOnce(MultiplyBy{7});
+```
+
+It's also fine for the callable to take no arguments, ignoring the arguments
+supplied to the mock function:
+
+```cpp
+MockFunction<int(int)> mock;
+EXPECT_CALL(mock, Call).WillOnce([] { return 17; });
+EXPECT_EQ(17, mock.AsStdFunction()(0));
+```
+
+When used with `WillOnce`, the callable can assume it will be called at most
+once and is allowed to be a move-only type:
+
+```cpp
+// An action that contains move-only types and has an &&-qualified operator,
+// demanding in the type system that it be called at most once. This can be
+// used with WillOnce, but the compiler will reject it if handed to
+// WillRepeatedly.
+struct MoveOnlyAction {
+  std::unique_ptr<int> move_only_state;
+  std::unique_ptr<int> operator()() && { return std::move(move_only_state); }
+};
+
+MockFunction<std::unique_ptr<int>()> mock;
+EXPECT_CALL(mock, Call).WillOnce(MoveOnlyAction{std::make_unique<int>(17)});
+EXPECT_THAT(mock.AsStdFunction()(), Pointee(Eq(17)));
+```
+
+More generally, to use with a mock function whose signature is `R(Args...)` the
+object can be anything convertible to `OnceAction<R(Args...)>` or
+`Action<R(Args...)>`. The difference between the two is that `OnceAction` has
+weaker requirements (`Action` requires a copy-constructible input that can be
+called repeatedly, whereas `OnceAction` requires only move-constructible and
+supports `&&`-qualified call operators), but can be used only with `WillOnce`.
+`OnceAction` is typically relevant only when supporting move-only types or
+actions that want a type-system guarantee that they will be called at most once.
+
+Typically the `OnceAction` and `Action` templates need not be referenced
+directly in your actions: a struct or class with a call operator is sufficient,
+as in the examples above. But fancier polymorphic actions that need to know the
+specific return type of the mock function can define templated conversion
+operators to make that possible. See `gmock-actions.h` for examples.
+
+#### Legacy macro-based Actions
+
+Before C++11, the functor-based actions were not supported; the old way of
+writing actions was through a set of `ACTION*` macros.
We suggest to avoid them +in new code; they hide a lot of logic behind the macro, potentially leading to +harder-to-understand compiler errors. Nevertheless, we cover them here for +completeness. + +By writing + +```cpp +ACTION(name) { statements; } +``` + +in a namespace scope (i.e. not inside a class or function), you will define an +action with the given name that executes the statements. The value returned by +`statements` will be used as the return value of the action. Inside the +statements, you can refer to the K-th (0-based) argument of the mock function as +`argK`. For example: + +```cpp +ACTION(IncrementArg1) { return ++(*arg1); } +``` + +allows you to write + +```cpp +... WillOnce(IncrementArg1()); +``` + +Note that you don't need to specify the types of the mock function arguments. +Rest assured that your code is type-safe though: you'll get a compiler error if +`*arg1` doesn't support the `++` operator, or if the type of `++(*arg1)` isn't +compatible with the mock function's return type. + +Another example: + +```cpp +ACTION(Foo) { + (*arg2)(5); + Blah(); + *arg1 = 0; + return arg0; +} +``` + +defines an action `Foo()` that invokes argument #2 (a function pointer) with 5, +calls function `Blah()`, sets the value pointed to by argument #1 to 0, and +returns argument #0. + +For more convenience and flexibility, you can also use the following pre-defined +symbols in the body of `ACTION`: + +`argK_type` | The type of the K-th (0-based) argument of the mock function +:-------------- | :----------------------------------------------------------- +`args` | All arguments of the mock function as a tuple +`args_type` | The type of all arguments of the mock function as a tuple +`return_type` | The return type of the mock function +`function_type` | The type of the mock function + +For example, when using an `ACTION` as a stub action for mock function: + +```cpp +int DoSomething(bool flag, int* ptr); +``` + +we have: + +Pre-defined Symbol | Is Bound To +------------------ | --------------------------------- +`arg0` | the value of `flag` +`arg0_type` | the type `bool` +`arg1` | the value of `ptr` +`arg1_type` | the type `int*` +`args` | the tuple `(flag, ptr)` +`args_type` | the type `std::tuple` +`return_type` | the type `int` +`function_type` | the type `int(bool, int*)` + +#### Legacy macro-based parameterized Actions + +Sometimes you'll want to parameterize an action you define. For that we have +another macro + +```cpp +ACTION_P(name, param) { statements; } +``` + +For example, + +```cpp +ACTION_P(Add, n) { return arg0 + n; } +``` + +will allow you to write + +```cpp +// Returns argument #0 + 5. +... WillOnce(Add(5)); +``` + +For convenience, we use the term *arguments* for the values used to invoke the +mock function, and the term *parameters* for the values used to instantiate an +action. + +Note that you don't need to provide the type of the parameter either. Suppose +the parameter is named `param`, you can also use the gMock-defined symbol +`param_type` to refer to the type of the parameter as inferred by the compiler. +For example, in the body of `ACTION_P(Add, n)` above, you can write `n_type` for +the type of `n`. + +gMock also provides `ACTION_P2`, `ACTION_P3`, and etc to support multi-parameter +actions. For example, + +```cpp +ACTION_P2(ReturnDistanceTo, x, y) { + double dx = arg0 - x; + double dy = arg1 - y; + return sqrt(dx*dx + dy*dy); +} +``` + +lets you write + +```cpp +... 
WillOnce(ReturnDistanceTo(5.0, 26.5));
+```
+
+You can view `ACTION` as a degenerate parameterized action where the number of
+parameters is 0.
+
+You can also easily define actions overloaded on the number of parameters:
+
+```cpp
+ACTION_P(Plus, a) { ... }
+ACTION_P2(Plus, a, b) { ... }
+```
+
+### Restricting the Type of an Argument or Parameter in an ACTION
+
+For maximum brevity and reusability, the `ACTION*` macros don't ask you to
+provide the types of the mock function arguments and the action parameters.
+Instead, we let the compiler infer the types for us.
+
+Sometimes, however, we may want to be more explicit about the types. There are
+several tricks to do that. For example:
+
+```cpp
+ACTION(Foo) {
+  // Makes sure arg0 can be converted to int.
+  int n = arg0;
+  ... use n instead of arg0 here ...
+}
+
+ACTION_P(Bar, param) {
+  // Makes sure the type of arg1 is const char*.
+  ::testing::StaticAssertTypeEq<const char*, arg1_type>();
+
+  // Makes sure param can be converted to bool.
+  bool flag = param;
+}
+```
+
+where `StaticAssertTypeEq` is a compile-time assertion in googletest that
+verifies two types are the same.
+
+### Writing New Action Templates Quickly
+
+Sometimes you want to give an action explicit template parameters that cannot be
+inferred from its value parameters. `ACTION_TEMPLATE()` supports that and can be
+viewed as an extension to `ACTION()` and `ACTION_P*()`.
+
+The syntax:
+
+```cpp
+ACTION_TEMPLATE(ActionName,
+                HAS_m_TEMPLATE_PARAMS(kind1, name1, ..., kind_m, name_m),
+                AND_n_VALUE_PARAMS(p1, ..., p_n)) { statements; }
+```
+
+defines an action template that takes *m* explicit template parameters and *n*
+value parameters, where *m* is in [1, 10] and *n* is in [0, 10]. `name_i` is the
+name of the *i*-th template parameter, and `kind_i` specifies whether it's a
+`typename`, an integral constant, or a template. `p_i` is the name of the *i*-th
+value parameter.
+
+Example:
+
+```cpp
+// DuplicateArg<k, T>(output) converts the k-th argument of the mock
+// function to type T and copies it to *output.
+ACTION_TEMPLATE(DuplicateArg,
+                // Note the comma between int and k:
+                HAS_2_TEMPLATE_PARAMS(int, k, typename, T),
+                AND_1_VALUE_PARAMS(output)) {
+  *output = T(std::get<k>(args));
+}
+```
+
+To create an instance of an action template, write:
+
+```cpp
+ActionName<t1, ..., t_m>(v1, ..., v_n)
+```
+
+where the `t`s are the template arguments and the `v`s are the value arguments.
+The value argument types are inferred by the compiler. For example:
+
+```cpp
+using ::testing::_;
+...
+  int n;
+  EXPECT_CALL(mock, Foo).WillOnce(DuplicateArg<1, unsigned char>(&n));
+```
+
+If you want to explicitly specify the value argument types, you can provide
+additional template arguments:
+
+```cpp
+ActionName<t1, ..., t_m, u1, ..., u_k>(v1, ..., v_n)
+```
+
+where `u_i` is the desired type of `v_i`.
+
+`ACTION_TEMPLATE` and `ACTION`/`ACTION_P*` can be overloaded on the number of
+value parameters, but not on the number of template parameters. Without the
+restriction, the meaning of the following is unclear:
+
+```cpp
+  OverloadedAction<bool>(x);
+```
+
+Are we using a single-template-parameter action where `bool` refers to the type
+of `x`, or a two-template-parameter action where the compiler is asked to infer
+the type of `x`?
+
+### Using the ACTION Object's Type
+
+If you are writing a function that returns an `ACTION` object, you'll need to
+know its type. The type depends on the macro used to define the action and the
+parameter types. The rule is relatively simple:
+
+| Given Definition              | Expression          | Has Type              |
+| ----------------------------- | ------------------- | --------------------- |
+| `ACTION(Foo)`                 | `Foo()`             | `FooAction`           |
+| `ACTION_TEMPLATE(Foo, HAS_m_TEMPLATE_PARAMS(...), AND_0_VALUE_PARAMS())` | `Foo<t1, ..., t_m>()` | `FooAction<t1, ..., t_m>` |
+| `ACTION_P(Bar, param)`        | `Bar(int_value)`    | `BarActionP<int>` |
+| `ACTION_TEMPLATE(Bar, HAS_m_TEMPLATE_PARAMS(...), AND_1_VALUE_PARAMS(p1))` | `Bar<t1, ..., t_m>(int_value)` | `BarActionP<t1, ..., t_m, int>` |
+| `ACTION_P2(Baz, p1, p2)`      | `Baz(bool_value, int_value)` | `BazActionP2<bool, int>` |
+| `ACTION_TEMPLATE(Baz, HAS_m_TEMPLATE_PARAMS(...), AND_2_VALUE_PARAMS(p1, p2))` | `Baz<t1, ..., t_m>(bool_value, int_value)` | `BazActionP2<t1, ..., t_m, bool, int>` |
+| ...                           | ...                 | ...                   |
+
+Note that we have to pick different suffixes (`Action`, `ActionP`, `ActionP2`,
+and so on) for actions with different numbers of value parameters, or the action
+definitions cannot be overloaded on the number of them.
+
+### Writing New Monomorphic Actions {#NewMonoActions}
+
+While the `ACTION*` macros are very convenient, sometimes they are
+inappropriate. For example, despite the tricks shown in the previous recipes,
+they don't let you directly specify the types of the mock function arguments and
+the action parameters, which in general leads to unhelpful compiler error
+messages that can baffle unfamiliar users. They also don't allow overloading
+actions based on parameter types without jumping through some hoops.
+
+An alternative to the `ACTION*` macros is to implement
+`::testing::ActionInterface<F>`, where `F` is the type of the mock function in
+which the action will be used. For example:
+
+```cpp
+template <typename F>
+class ActionInterface {
+ public:
+  virtual ~ActionInterface();
+
+  // Performs the action. Result is the return type of function type
+  // F, and ArgumentTuple is the tuple of arguments of F.
+  //
+  // For example, if F is int(bool, const string&), then Result would
+  // be int, and ArgumentTuple would be std::tuple<bool, const string&>.
+  virtual Result Perform(const ArgumentTuple& args) = 0;
+};
+```
+
+```cpp
+using ::testing::_;
+using ::testing::Action;
+using ::testing::ActionInterface;
+using ::testing::MakeAction;
+
+typedef int IncrementMethod(int*);
+
+class IncrementArgumentAction : public ActionInterface<IncrementMethod> {
+ public:
+  int Perform(const std::tuple<int*>& args) override {
+    int* p = std::get<0>(args);  // Grabs the first argument.
+    return (*p)++;               // Returns the old value and increments it.
+  }
+};
+
+Action<IncrementMethod> IncrementArgument() {
+  return MakeAction(new IncrementArgumentAction);
+}
+
+...
+  EXPECT_CALL(foo, Baz(_))
+      .WillOnce(IncrementArgument());
+
+  int n = 5;
+  foo.Baz(&n);  // Should return 5 and change n to 6.
+```
+
+### Writing New Polymorphic Actions {#NewPolyActions}
+
+The previous recipe showed you how to define your own action. This is all good,
+except that you need to know the type of the function in which the action will
+be used. Sometimes that can be a problem. For example, if you want to use the
+action in functions with *different* types (e.g. like `Return()` and
+`SetArgPointee()`).
+
+If an action can be used in several types of mock functions, we say it's
+*polymorphic*. The `MakePolymorphicAction()` function template makes it easy to
+define such an action:
+
+```cpp
+namespace testing {
+template <typename Impl>
+PolymorphicAction<Impl> MakePolymorphicAction(const Impl& impl);
+}  // namespace testing
+```
+
+As an example, let's define an action that returns the second argument in the
+mock function's argument list.
The first step is to define an implementation +class: + +```cpp +class ReturnSecondArgumentAction { + public: + template + Result Perform(const ArgumentTuple& args) const { + // To get the i-th (0-based) argument, use std::get(args). + return std::get<1>(args); + } +}; +``` + +This implementation class does *not* need to inherit from any particular class. +What matters is that it must have a `Perform()` method template. This method +template takes the mock function's arguments as a tuple in a **single** +argument, and returns the result of the action. It can be either `const` or not, +but must be invocable with exactly one template argument, which is the result +type. In other words, you must be able to call `Perform(args)` where `R` is +the mock function's return type and `args` is its arguments in a tuple. + +Next, we use `MakePolymorphicAction()` to turn an instance of the implementation +class into the polymorphic action we need. It will be convenient to have a +wrapper for this: + +```cpp +using ::testing::MakePolymorphicAction; +using ::testing::PolymorphicAction; + +PolymorphicAction ReturnSecondArgument() { + return MakePolymorphicAction(ReturnSecondArgumentAction()); +} +``` + +Now, you can use this polymorphic action the same way you use the built-in ones: + +```cpp +using ::testing::_; + +class MockFoo : public Foo { + public: + MOCK_METHOD(int, DoThis, (bool flag, int n), (override)); + MOCK_METHOD(string, DoThat, (int x, const char* str1, const char* str2), + (override)); +}; + + ... + MockFoo foo; + EXPECT_CALL(foo, DoThis).WillOnce(ReturnSecondArgument()); + EXPECT_CALL(foo, DoThat).WillOnce(ReturnSecondArgument()); + ... + foo.DoThis(true, 5); // Will return 5. + foo.DoThat(1, "Hi", "Bye"); // Will return "Hi". +``` + +### Teaching gMock How to Print Your Values + +When an uninteresting or unexpected call occurs, gMock prints the argument +values and the stack trace to help you debug. Assertion macros like +`EXPECT_THAT` and `EXPECT_EQ` also print the values in question when the +assertion fails. gMock and googletest do this using googletest's user-extensible +value printer. + +This printer knows how to print built-in C++ types, native arrays, STL +containers, and any type that supports the `<<` operator. For other types, it +prints the raw bytes in the value and hopes that you the user can figure it out. +[The GoogleTest advanced guide](advanced.md#teaching-googletest-how-to-print-your-values) +explains how to extend the printer to do a better job at printing your +particular type than to dump the bytes. + +## Useful Mocks Created Using gMock + + + + +### Mock std::function {#MockFunction} + +`std::function` is a general function type introduced in C++11. It is a +preferred way of passing callbacks to new interfaces. Functions are copiable, +and are not usually passed around by pointer, which makes them tricky to mock. +But fear not - `MockFunction` can help you with that. + +`MockFunction` has a mock method `Call()` with the signature: + +```cpp + R Call(T1, ..., Tn); +``` + +It also has a `AsStdFunction()` method, which creates a `std::function` proxy +forwarding to Call: + +```cpp + std::function AsStdFunction(); +``` + +To use `MockFunction`, first create `MockFunction` object and set up +expectations on its `Call` method. Then pass proxy obtained from +`AsStdFunction()` to the code you are testing. For example: + +```cpp +TEST(FooTest, RunsCallbackWithBarArgument) { + // 1. Create a mock object. + MockFunction mock_function; + + // 2. 
Set expectations on Call() method. + EXPECT_CALL(mock_function, Call("bar")).WillOnce(Return(1)); + + // 3. Exercise code that uses std::function. + Foo(mock_function.AsStdFunction()); + // Foo's signature can be either of: + // void Foo(const std::function& fun); + // void Foo(std::function fun); + + // 4. All expectations will be verified when mock_function + // goes out of scope and is destroyed. +} +``` + +Remember that function objects created with `AsStdFunction()` are just +forwarders. If you create multiple of them, they will share the same set of +expectations. + +Although `std::function` supports unlimited number of arguments, `MockFunction` +implementation is limited to ten. If you ever hit that limit... well, your +callback has bigger problems than being mockable. :-) diff --git a/3rdparty/googletest-1.13.0/docs/gmock_faq.md b/3rdparty/googletest-1.13.0/docs/gmock_faq.md new file mode 100644 index 0000000000000000000000000000000000000000..8f220bf7a8fec033ed9cb827a794397315962fcc --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/gmock_faq.md @@ -0,0 +1,390 @@ +# Legacy gMock FAQ + +### When I call a method on my mock object, the method for the real object is invoked instead. What's the problem? + +In order for a method to be mocked, it must be *virtual*, unless you use the +[high-perf dependency injection technique](gmock_cook_book.md#MockingNonVirtualMethods). + +### Can I mock a variadic function? + +You cannot mock a variadic function (i.e. a function taking ellipsis (`...`) +arguments) directly in gMock. + +The problem is that in general, there is *no way* for a mock object to know how +many arguments are passed to the variadic method, and what the arguments' types +are. Only the *author of the base class* knows the protocol, and we cannot look +into his or her head. + +Therefore, to mock such a function, the *user* must teach the mock object how to +figure out the number of arguments and their types. One way to do it is to +provide overloaded versions of the function. + +Ellipsis arguments are inherited from C and not really a C++ feature. They are +unsafe to use and don't work with arguments that have constructors or +destructors. Therefore we recommend to avoid them in C++ as much as possible. + +### MSVC gives me warning C4301 or C4373 when I define a mock method with a const parameter. Why? + +If you compile this using Microsoft Visual C++ 2005 SP1: + +```cpp +class Foo { + ... + virtual void Bar(const int i) = 0; +}; + +class MockFoo : public Foo { + ... + MOCK_METHOD(void, Bar, (const int i), (override)); +}; +``` + +You may get the following warning: + +```shell +warning C4301: 'MockFoo::Bar': overriding virtual function only differs from 'Foo::Bar' by const/volatile qualifier +``` + +This is a MSVC bug. The same code compiles fine with gcc, for example. If you +use Visual C++ 2008 SP1, you would get the warning: + +```shell +warning C4373: 'MockFoo::Bar': virtual function overrides 'Foo::Bar', previous versions of the compiler did not override when parameters only differed by const/volatile qualifiers +``` + +In C++, if you *declare* a function with a `const` parameter, the `const` +modifier is ignored. Therefore, the `Foo` base class above is equivalent to: + +```cpp +class Foo { + ... + virtual void Bar(int i) = 0; // int or const int? Makes no difference. +}; +``` + +In fact, you can *declare* `Bar()` with an `int` parameter, and define it with a +`const int` parameter. The compiler will still match them up. 
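+
+For instance, the following stand-alone sketch (using a hypothetical free
+function `Frobnicate`, not the `Foo`/`MockFoo` pair above) declares the
+parameter as `int` and defines it as `const int`, and the two still refer to a
+single function:
+
+```cpp
+// Hypothetical illustration: both lines name the same function, because a
+// top-level const on a parameter is not part of the function's signature.
+void Frobnicate(int x);         // Declaration sees the parameter as int.
+
+void Frobnicate(const int x) {  // Definition; the const only stops the body
+  // from modifying its local copy of x.
+}
+```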
+ +Since making a parameter `const` is meaningless in the method declaration, we +recommend to remove it in both `Foo` and `MockFoo`. That should workaround the +VC bug. + +Note that we are talking about the *top-level* `const` modifier here. If the +function parameter is passed by pointer or reference, declaring the pointee or +referee as `const` is still meaningful. For example, the following two +declarations are *not* equivalent: + +```cpp +void Bar(int* p); // Neither p nor *p is const. +void Bar(const int* p); // p is not const, but *p is. +``` + +### I can't figure out why gMock thinks my expectations are not satisfied. What should I do? + +You might want to run your test with `--gmock_verbose=info`. This flag lets +gMock print a trace of every mock function call it receives. By studying the +trace, you'll gain insights on why the expectations you set are not met. + +If you see the message "The mock function has no default action set, and its +return type has no default value set.", then try +[adding a default action](gmock_cheat_sheet.md#OnCall). Due to a known issue, +unexpected calls on mocks without default actions don't print out a detailed +comparison between the actual arguments and the expected arguments. + +### My program crashed and `ScopedMockLog` spit out tons of messages. Is it a gMock bug? + +gMock and `ScopedMockLog` are likely doing the right thing here. + +When a test crashes, the failure signal handler will try to log a lot of +information (the stack trace, and the address map, for example). The messages +are compounded if you have many threads with depth stacks. When `ScopedMockLog` +intercepts these messages and finds that they don't match any expectations, it +prints an error for each of them. + +You can learn to ignore the errors, or you can rewrite your expectations to make +your test more robust, for example, by adding something like: + +```cpp +using ::testing::AnyNumber; +using ::testing::Not; +... + // Ignores any log not done by us. + EXPECT_CALL(log, Log(_, Not(EndsWith("/my_file.cc")), _)) + .Times(AnyNumber()); +``` + +### How can I assert that a function is NEVER called? + +```cpp +using ::testing::_; +... + EXPECT_CALL(foo, Bar(_)) + .Times(0); +``` + +### I have a failed test where gMock tells me TWICE that a particular expectation is not satisfied. Isn't this redundant? + +When gMock detects a failure, it prints relevant information (the mock function +arguments, the state of relevant expectations, and etc) to help the user debug. +If another failure is detected, gMock will do the same, including printing the +state of relevant expectations. + +Sometimes an expectation's state didn't change between two failures, and you'll +see the same description of the state twice. They are however *not* redundant, +as they refer to *different points in time*. The fact they are the same *is* +interesting information. + +### I get a heapcheck failure when using a mock object, but using a real object is fine. What can be wrong? + +Does the class (hopefully a pure interface) you are mocking have a virtual +destructor? + +Whenever you derive from a base class, make sure its destructor is virtual. +Otherwise Bad Things will happen. Consider the following code: + +```cpp +class Base { + public: + // Not virtual, but should be. + ~Base() { ... } + ... +}; + +class Derived : public Base { + public: + ... + private: + std::string value_; +}; + +... + Base* p = new Derived; + ... + delete p; // Surprise! 
~Base() will be called, but ~Derived() will not + // - value_ is leaked. +``` + +By changing `~Base()` to virtual, `~Derived()` will be correctly called when +`delete p` is executed, and the heap checker will be happy. + +### The "newer expectations override older ones" rule makes writing expectations awkward. Why does gMock do that? + +When people complain about this, often they are referring to code like: + +```cpp +using ::testing::Return; +... + // foo.Bar() should be called twice, return 1 the first time, and return + // 2 the second time. However, I have to write the expectations in the + // reverse order. This sucks big time!!! + EXPECT_CALL(foo, Bar()) + .WillOnce(Return(2)) + .RetiresOnSaturation(); + EXPECT_CALL(foo, Bar()) + .WillOnce(Return(1)) + .RetiresOnSaturation(); +``` + +The problem, is that they didn't pick the **best** way to express the test's +intent. + +By default, expectations don't have to be matched in *any* particular order. If +you want them to match in a certain order, you need to be explicit. This is +gMock's (and jMock's) fundamental philosophy: it's easy to accidentally +over-specify your tests, and we want to make it harder to do so. + +There are two better ways to write the test spec. You could either put the +expectations in sequence: + +```cpp +using ::testing::Return; +... + // foo.Bar() should be called twice, return 1 the first time, and return + // 2 the second time. Using a sequence, we can write the expectations + // in their natural order. + { + InSequence s; + EXPECT_CALL(foo, Bar()) + .WillOnce(Return(1)) + .RetiresOnSaturation(); + EXPECT_CALL(foo, Bar()) + .WillOnce(Return(2)) + .RetiresOnSaturation(); + } +``` + +or you can put the sequence of actions in the same expectation: + +```cpp +using ::testing::Return; +... + // foo.Bar() should be called twice, return 1 the first time, and return + // 2 the second time. + EXPECT_CALL(foo, Bar()) + .WillOnce(Return(1)) + .WillOnce(Return(2)) + .RetiresOnSaturation(); +``` + +Back to the original questions: why does gMock search the expectations (and +`ON_CALL`s) from back to front? Because this allows a user to set up a mock's +behavior for the common case early (e.g. in the mock's constructor or the test +fixture's set-up phase) and customize it with more specific rules later. If +gMock searches from front to back, this very useful pattern won't be possible. + +### gMock prints a warning when a function without EXPECT_CALL is called, even if I have set its behavior using ON_CALL. Would it be reasonable not to show the warning in this case? + +When choosing between being neat and being safe, we lean toward the latter. So +the answer is that we think it's better to show the warning. + +Often people write `ON_CALL`s in the mock object's constructor or `SetUp()`, as +the default behavior rarely changes from test to test. Then in the test body +they set the expectations, which are often different for each test. Having an +`ON_CALL` in the set-up part of a test doesn't mean that the calls are expected. +If there's no `EXPECT_CALL` and the method is called, it's possibly an error. If +we quietly let the call go through without notifying the user, bugs may creep in +unnoticed. + +If, however, you are sure that the calls are OK, you can write + +```cpp +using ::testing::_; +... + EXPECT_CALL(foo, Bar(_)) + .WillRepeatedly(...); +``` + +instead of + +```cpp +using ::testing::_; +... 
+ ON_CALL(foo, Bar(_)) + .WillByDefault(...); +``` + +This tells gMock that you do expect the calls and no warning should be printed. + +Also, you can control the verbosity by specifying `--gmock_verbose=error`. Other +values are `info` and `warning`. If you find the output too noisy when +debugging, just choose a less verbose level. + +### How can I delete the mock function's argument in an action? + +If your mock function takes a pointer argument and you want to delete that +argument, you can use testing::DeleteArg() to delete the N'th (zero-indexed) +argument: + +```cpp +using ::testing::_; + ... + MOCK_METHOD(void, Bar, (X* x, const Y& y)); + ... + EXPECT_CALL(mock_foo_, Bar(_, _)) + .WillOnce(testing::DeleteArg<0>())); +``` + +### How can I perform an arbitrary action on a mock function's argument? + +If you find yourself needing to perform some action that's not supported by +gMock directly, remember that you can define your own actions using +[`MakeAction()`](#NewMonoActions) or +[`MakePolymorphicAction()`](#NewPolyActions), or you can write a stub function +and invoke it using [`Invoke()`](#FunctionsAsActions). + +```cpp +using ::testing::_; +using ::testing::Invoke; + ... + MOCK_METHOD(void, Bar, (X* p)); + ... + EXPECT_CALL(mock_foo_, Bar(_)) + .WillOnce(Invoke(MyAction(...))); +``` + +### My code calls a static/global function. Can I mock it? + +You can, but you need to make some changes. + +In general, if you find yourself needing to mock a static function, it's a sign +that your modules are too tightly coupled (and less flexible, less reusable, +less testable, etc). You are probably better off defining a small interface and +call the function through that interface, which then can be easily mocked. It's +a bit of work initially, but usually pays for itself quickly. + +This Google Testing Blog +[post](https://testing.googleblog.com/2008/06/defeat-static-cling.html) says it +excellently. Check it out. + +### My mock object needs to do complex stuff. It's a lot of pain to specify the actions. gMock sucks! + +I know it's not a question, but you get an answer for free any way. :-) + +With gMock, you can create mocks in C++ easily. And people might be tempted to +use them everywhere. Sometimes they work great, and sometimes you may find them, +well, a pain to use. So, what's wrong in the latter case? + +When you write a test without using mocks, you exercise the code and assert that +it returns the correct value or that the system is in an expected state. This is +sometimes called "state-based testing". + +Mocks are great for what some call "interaction-based" testing: instead of +checking the system state at the very end, mock objects verify that they are +invoked the right way and report an error as soon as it arises, giving you a +handle on the precise context in which the error was triggered. This is often +more effective and economical to do than state-based testing. + +If you are doing state-based testing and using a test double just to simulate +the real object, you are probably better off using a fake. Using a mock in this +case causes pain, as it's not a strong point for mocks to perform complex +actions. If you experience this and think that mocks suck, you are just not +using the right tool for your problem. Or, you might be trying to solve the +wrong problem. :-) + +### I got a warning "Uninteresting function call encountered - default action taken.." Should I panic? + +By all means, NO! It's just an FYI. 
:-) + +What it means is that you have a mock function, you haven't set any expectations +on it (by gMock's rule this means that you are not interested in calls to this +function and therefore it can be called any number of times), and it is called. +That's OK - you didn't say it's not OK to call the function! + +What if you actually meant to disallow this function to be called, but forgot to +write `EXPECT_CALL(foo, Bar()).Times(0)`? While one can argue that it's the +user's fault, gMock tries to be nice and prints you a note. + +So, when you see the message and believe that there shouldn't be any +uninteresting calls, you should investigate what's going on. To make your life +easier, gMock dumps the stack trace when an uninteresting call is encountered. +From that you can figure out which mock function it is, and how it is called. + +### I want to define a custom action. Should I use Invoke() or implement the ActionInterface interface? + +Either way is fine - you want to choose the one that's more convenient for your +circumstance. + +Usually, if your action is for a particular function type, defining it using +`Invoke()` should be easier; if your action can be used in functions of +different types (e.g. if you are defining `Return(*value*)`), +`MakePolymorphicAction()` is easiest. Sometimes you want precise control on what +types of functions the action can be used in, and implementing `ActionInterface` +is the way to go here. See the implementation of `Return()` in `gmock-actions.h` +for an example. + +### I use SetArgPointee() in WillOnce(), but gcc complains about "conflicting return type specified". What does it mean? + +You got this error as gMock has no idea what value it should return when the +mock method is called. `SetArgPointee()` says what the side effect is, but +doesn't say what the return value should be. You need `DoAll()` to chain a +`SetArgPointee()` with a `Return()` that provides a value appropriate to the API +being mocked. + +See this [recipe](gmock_cook_book.md#mocking-side-effects) for more details and +an example. + +### I have a huge mock class, and Microsoft Visual C++ runs out of memory when compiling it. What can I do? + +We've noticed that when the `/clr` compiler flag is used, Visual C++ uses 5~6 +times as much memory when compiling a mock class. We suggest to avoid `/clr` +when compiling native C++ mocks. diff --git a/3rdparty/googletest-1.13.0/docs/gmock_for_dummies.md b/3rdparty/googletest-1.13.0/docs/gmock_for_dummies.md new file mode 100644 index 0000000000000000000000000000000000000000..b7264d3587f71ada659741bd6c47ac015ff46e99 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/gmock_for_dummies.md @@ -0,0 +1,700 @@ +# gMock for Dummies + +## What Is gMock? + +When you write a prototype or test, often it's not feasible or wise to rely on +real objects entirely. A **mock object** implements the same interface as a real +object (so it can be used as one), but lets you specify at run time how it will +be used and what it should do (which methods will be called? in which order? how +many times? with what arguments? what will they return? etc). + +It is easy to confuse the term *fake objects* with mock objects. Fakes and mocks +actually mean very different things in the Test-Driven Development (TDD) +community: + +* **Fake** objects have working implementations, but usually take some + shortcut (perhaps to make the operations less expensive), which makes them + not suitable for production. An in-memory file system would be an example of + a fake. 
+* **Mocks** are objects pre-programmed with *expectations*, which form a + specification of the calls they are expected to receive. + +If all this seems too abstract for you, don't worry - the most important thing +to remember is that a mock allows you to check the *interaction* between itself +and code that uses it. The difference between fakes and mocks shall become much +clearer once you start to use mocks. + +**gMock** is a library (sometimes we also call it a "framework" to make it sound +cool) for creating mock classes and using them. It does to C++ what +jMock/EasyMock does to Java (well, more or less). + +When using gMock, + +1. first, you use some simple macros to describe the interface you want to + mock, and they will expand to the implementation of your mock class; +2. next, you create some mock objects and specify its expectations and behavior + using an intuitive syntax; +3. then you exercise code that uses the mock objects. gMock will catch any + violation to the expectations as soon as it arises. + +## Why gMock? + +While mock objects help you remove unnecessary dependencies in tests and make +them fast and reliable, using mocks manually in C++ is *hard*: + +* Someone has to implement the mocks. The job is usually tedious and + error-prone. No wonder people go great distance to avoid it. +* The quality of those manually written mocks is a bit, uh, unpredictable. You + may see some really polished ones, but you may also see some that were + hacked up in a hurry and have all sorts of ad hoc restrictions. +* The knowledge you gained from using one mock doesn't transfer to the next + one. + +In contrast, Java and Python programmers have some fine mock frameworks (jMock, +EasyMock, etc), which automate the creation of mocks. As a result, mocking is a +proven effective technique and widely adopted practice in those communities. +Having the right tool absolutely makes the difference. + +gMock was built to help C++ programmers. It was inspired by jMock and EasyMock, +but designed with C++'s specifics in mind. It is your friend if any of the +following problems is bothering you: + +* You are stuck with a sub-optimal design and wish you had done more + prototyping before it was too late, but prototyping in C++ is by no means + "rapid". +* Your tests are slow as they depend on too many libraries or use expensive + resources (e.g. a database). +* Your tests are brittle as some resources they use are unreliable (e.g. the + network). +* You want to test how your code handles a failure (e.g. a file checksum + error), but it's not easy to cause one. +* You need to make sure that your module interacts with other modules in the + right way, but it's hard to observe the interaction; therefore you resort to + observing the side effects at the end of the action, but it's awkward at + best. +* You want to "mock out" your dependencies, except that they don't have mock + implementations yet; and, frankly, you aren't thrilled by some of those + hand-written mocks. + +We encourage you to use gMock as + +* a *design* tool, for it lets you experiment with your interface design early + and often. More iterations lead to better designs! +* a *testing* tool to cut your tests' outbound dependencies and probe the + interaction between your module and its collaborators. + +## Getting Started + +gMock is bundled with googletest. + +## A Case for Mock Turtles + +Let's look at an example. 
Suppose you are developing a graphics program that +relies on a [LOGO](http://en.wikipedia.org/wiki/Logo_programming_language)-like +API for drawing. How would you test that it does the right thing? Well, you can +run it and compare the screen with a golden screen snapshot, but let's admit it: +tests like this are expensive to run and fragile (What if you just upgraded to a +shiny new graphics card that has better anti-aliasing? Suddenly you have to +update all your golden images.). It would be too painful if all your tests are +like this. Fortunately, you learned about +[Dependency Injection](http://en.wikipedia.org/wiki/Dependency_injection) and know the right thing +to do: instead of having your application talk to the system API directly, wrap +the API in an interface (say, `Turtle`) and code to that interface: + +```cpp +class Turtle { + ... + virtual ~Turtle() {} + virtual void PenUp() = 0; + virtual void PenDown() = 0; + virtual void Forward(int distance) = 0; + virtual void Turn(int degrees) = 0; + virtual void GoTo(int x, int y) = 0; + virtual int GetX() const = 0; + virtual int GetY() const = 0; +}; +``` + +(Note that the destructor of `Turtle` **must** be virtual, as is the case for +**all** classes you intend to inherit from - otherwise the destructor of the +derived class will not be called when you delete an object through a base +pointer, and you'll get corrupted program states like memory leaks.) + +You can control whether the turtle's movement will leave a trace using `PenUp()` +and `PenDown()`, and control its movement using `Forward()`, `Turn()`, and +`GoTo()`. Finally, `GetX()` and `GetY()` tell you the current position of the +turtle. + +Your program will normally use a real implementation of this interface. In +tests, you can use a mock implementation instead. This allows you to easily +check what drawing primitives your program is calling, with what arguments, and +in which order. Tests written this way are much more robust (they won't break +because your new machine does anti-aliasing differently), easier to read and +maintain (the intent of a test is expressed in the code, not in some binary +images), and run *much, much faster*. + +## Writing the Mock Class + +If you are lucky, the mocks you need to use have already been implemented by +some nice people. If, however, you find yourself in the position to write a mock +class, relax - gMock turns this task into a fun game! (Well, almost.) + +### How to Define It + +Using the `Turtle` interface as example, here are the simple steps you need to +follow: + +* Derive a class `MockTurtle` from `Turtle`. +* Take a *virtual* function of `Turtle` (while it's possible to + [mock non-virtual methods using templates](gmock_cook_book.md#MockingNonVirtualMethods), + it's much more involved). +* In the `public:` section of the child class, write `MOCK_METHOD();` +* Now comes the fun part: you take the function signature, cut-and-paste it + into the macro, and add two commas - one between the return type and the + name, another between the name and the argument list. +* If you're mocking a const method, add a 4th parameter containing `(const)` + (the parentheses are required). +* Since you're overriding a virtual method, we suggest adding the `override` + keyword. For const methods the 4th parameter becomes `(const, override)`, + for non-const methods just `(override)`. This isn't mandatory. +* Repeat until all virtual functions you want to mock are done. 
(It goes + without saying that *all* pure virtual methods in your abstract class must + be either mocked or overridden.) + +After the process, you should have something like: + +```cpp +#include "gmock/gmock.h" // Brings in gMock. + +class MockTurtle : public Turtle { + public: + ... + MOCK_METHOD(void, PenUp, (), (override)); + MOCK_METHOD(void, PenDown, (), (override)); + MOCK_METHOD(void, Forward, (int distance), (override)); + MOCK_METHOD(void, Turn, (int degrees), (override)); + MOCK_METHOD(void, GoTo, (int x, int y), (override)); + MOCK_METHOD(int, GetX, (), (const, override)); + MOCK_METHOD(int, GetY, (), (const, override)); +}; +``` + +You don't need to define these mock methods somewhere else - the `MOCK_METHOD` +macro will generate the definitions for you. It's that simple! + +### Where to Put It + +When you define a mock class, you need to decide where to put its definition. +Some people put it in a `_test.cc`. This is fine when the interface being mocked +(say, `Foo`) is owned by the same person or team. Otherwise, when the owner of +`Foo` changes it, your test could break. (You can't really expect `Foo`'s +maintainer to fix every test that uses `Foo`, can you?) + +Generally, you should not mock classes you don't own. If you must mock such a +class owned by others, define the mock class in `Foo`'s Bazel package (usually +the same directory or a `testing` sub-directory), and put it in a `.h` and a +`cc_library` with `testonly=True`. Then everyone can reference them from their +tests. If `Foo` ever changes, there is only one copy of `MockFoo` to change, and +only tests that depend on the changed methods need to be fixed. + +Another way to do it: you can introduce a thin layer `FooAdaptor` on top of +`Foo` and code to this new interface. Since you own `FooAdaptor`, you can absorb +changes in `Foo` much more easily. While this is more work initially, carefully +choosing the adaptor interface can make your code easier to write and more +readable (a net win in the long run), as you can choose `FooAdaptor` to fit your +specific domain much better than `Foo` does. + +## Using Mocks in Tests + +Once you have a mock class, using it is easy. The typical work flow is: + +1. Import the gMock names from the `testing` namespace such that you can use + them unqualified (You only have to do it once per file). Remember that + namespaces are a good idea. +2. Create some mock objects. +3. Specify your expectations on them (How many times will a method be called? + With what arguments? What should it do? etc.). +4. Exercise some code that uses the mocks; optionally, check the result using + googletest assertions. If a mock method is called more than expected or with + wrong arguments, you'll get an error immediately. +5. When a mock is destructed, gMock will automatically check whether all + expectations on it have been satisfied. + +Here's an example: + +```cpp +#include "path/to/mock-turtle.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using ::testing::AtLeast; // #1 + +TEST(PainterTest, CanDrawSomething) { + MockTurtle turtle; // #2 + EXPECT_CALL(turtle, PenDown()) // #3 + .Times(AtLeast(1)); + + Painter painter(&turtle); // #4 + + EXPECT_TRUE(painter.DrawCircle(0, 0, 10)); // #5 +} +``` + +As you might have guessed, this test checks that `PenDown()` is called at least +once. 
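+
+(For context, the `Painter` class itself is not shown in this guide; a minimal,
+hypothetical version that would satisfy the expectation might look like this:)
+
+```cpp
+// Hypothetical Painter, sketched only to make the test above concrete; the
+// real class you test will differ.
+class Painter {
+ public:
+  explicit Painter(Turtle* turtle) : turtle_(turtle) {}
+
+  bool DrawCircle(int x, int y, int radius) {
+    turtle_->PenDown();  // The call that EXPECT_CALL(turtle, PenDown()) verifies.
+    // ... actual drawing logic elided ...
+    return true;
+  }
+
+ private:
+  Turtle* turtle_;  // Not owned.
+};
+```
+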
If the `painter` object didn't call this method, your test will fail with +a message like this: + +```text +path/to/my_test.cc:119: Failure +Actual function call count doesn't match this expectation: +Actually: never called; +Expected: called at least once. +Stack trace: +... +``` + +**Tip 1:** If you run the test from an Emacs buffer, you can hit `` on +the line number to jump right to the failed expectation. + +**Tip 2:** If your mock objects are never deleted, the final verification won't +happen. Therefore it's a good idea to turn on the heap checker in your tests +when you allocate mocks on the heap. You get that automatically if you use the +`gtest_main` library already. + +**Important note:** gMock requires expectations to be set **before** the mock +functions are called, otherwise the behavior is **undefined**. Do not alternate +between calls to `EXPECT_CALL()` and calls to the mock functions, and do not set +any expectations on a mock after passing the mock to an API. + +This means `EXPECT_CALL()` should be read as expecting that a call will occur +*in the future*, not that a call has occurred. Why does gMock work like that? +Well, specifying the expectation beforehand allows gMock to report a violation +as soon as it rises, when the context (stack trace, etc) is still available. +This makes debugging much easier. + +Admittedly, this test is contrived and doesn't do much. You can easily achieve +the same effect without using gMock. However, as we shall reveal soon, gMock +allows you to do *so much more* with the mocks. + +## Setting Expectations + +The key to using a mock object successfully is to set the *right expectations* +on it. If you set the expectations too strict, your test will fail as the result +of unrelated changes. If you set them too loose, bugs can slip through. You want +to do it just right such that your test can catch exactly the kind of bugs you +intend it to catch. gMock provides the necessary means for you to do it "just +right." + +### General Syntax + +In gMock we use the `EXPECT_CALL()` macro to set an expectation on a mock +method. The general syntax is: + +```cpp +EXPECT_CALL(mock_object, method(matchers)) + .Times(cardinality) + .WillOnce(action) + .WillRepeatedly(action); +``` + +The macro has two arguments: first the mock object, and then the method and its +arguments. Note that the two are separated by a comma (`,`), not a period (`.`). +(Why using a comma? The answer is that it was necessary for technical reasons.) +If the method is not overloaded, the macro can also be called without matchers: + +```cpp +EXPECT_CALL(mock_object, non-overloaded-method) + .Times(cardinality) + .WillOnce(action) + .WillRepeatedly(action); +``` + +This syntax allows the test writer to specify "called with any arguments" +without explicitly specifying the number or types of arguments. To avoid +unintended ambiguity, this syntax may only be used for methods that are not +overloaded. + +Either form of the macro can be followed by some optional *clauses* that provide +more information about the expectation. We'll discuss how each clause works in +the coming sections. + +This syntax is designed to make an expectation read like English. For example, +you can probably guess that + +```cpp +using ::testing::Return; +... 
+EXPECT_CALL(turtle, GetX()) + .Times(5) + .WillOnce(Return(100)) + .WillOnce(Return(150)) + .WillRepeatedly(Return(200)); +``` + +says that the `turtle` object's `GetX()` method will be called five times, it +will return 100 the first time, 150 the second time, and then 200 every time. +Some people like to call this style of syntax a Domain-Specific Language (DSL). + +{: .callout .note} +**Note:** Why do we use a macro to do this? Well it serves two purposes: first +it makes expectations easily identifiable (either by `grep` or by a human +reader), and second it allows gMock to include the source file location of a +failed expectation in messages, making debugging easier. + +### Matchers: What Arguments Do We Expect? + +When a mock function takes arguments, we may specify what arguments we are +expecting, for example: + +```cpp +// Expects the turtle to move forward by 100 units. +EXPECT_CALL(turtle, Forward(100)); +``` + +Oftentimes you do not want to be too specific. Remember that talk about tests +being too rigid? Over specification leads to brittle tests and obscures the +intent of tests. Therefore we encourage you to specify only what's necessary—no +more, no less. If you aren't interested in the value of an argument, write `_` +as the argument, which means "anything goes": + +```cpp +using ::testing::_; +... +// Expects that the turtle jumps to somewhere on the x=50 line. +EXPECT_CALL(turtle, GoTo(50, _)); +``` + +`_` is an instance of what we call **matchers**. A matcher is like a predicate +and can test whether an argument is what we'd expect. You can use a matcher +inside `EXPECT_CALL()` wherever a function argument is expected. `_` is a +convenient way of saying "any value". + +In the above examples, `100` and `50` are also matchers; implicitly, they are +the same as `Eq(100)` and `Eq(50)`, which specify that the argument must be +equal (using `operator==`) to the matcher argument. There are many +[built-in matchers](reference/matchers.md) for common types (as well as +[custom matchers](gmock_cook_book.md#NewMatchers)); for example: + +```cpp +using ::testing::Ge; +... +// Expects the turtle moves forward by at least 100. +EXPECT_CALL(turtle, Forward(Ge(100))); +``` + +If you don't care about *any* arguments, rather than specify `_` for each of +them you may instead omit the parameter list: + +```cpp +// Expects the turtle to move forward. +EXPECT_CALL(turtle, Forward); +// Expects the turtle to jump somewhere. +EXPECT_CALL(turtle, GoTo); +``` + +This works for all non-overloaded methods; if a method is overloaded, you need +to help gMock resolve which overload is expected by specifying the number of +arguments and possibly also the +[types of the arguments](gmock_cook_book.md#SelectOverload). + +### Cardinalities: How Many Times Will It Be Called? + +The first clause we can specify following an `EXPECT_CALL()` is `Times()`. We +call its argument a **cardinality** as it tells *how many times* the call should +occur. It allows us to repeat an expectation many times without actually writing +it as many times. More importantly, a cardinality can be "fuzzy", just like a +matcher can be. This allows a user to express the intent of a test exactly. + +An interesting special case is when we say `Times(0)`. You may have guessed - it +means that the function shouldn't be called with the given arguments at all, and +gMock will report a googletest failure whenever the function is (wrongfully) +called. + +We've seen `AtLeast(n)` as an example of fuzzy cardinalities earlier. 
For the +list of built-in cardinalities you can use, see +[here](gmock_cheat_sheet.md#CardinalityList). + +The `Times()` clause can be omitted. **If you omit `Times()`, gMock will infer +the cardinality for you.** The rules are easy to remember: + +* If **neither** `WillOnce()` **nor** `WillRepeatedly()` is in the + `EXPECT_CALL()`, the inferred cardinality is `Times(1)`. +* If there are *n* `WillOnce()`'s but **no** `WillRepeatedly()`, where *n* >= + 1, the cardinality is `Times(n)`. +* If there are *n* `WillOnce()`'s and **one** `WillRepeatedly()`, where *n* >= + 0, the cardinality is `Times(AtLeast(n))`. + +**Quick quiz:** what do you think will happen if a function is expected to be +called twice but actually called four times? + +### Actions: What Should It Do? + +Remember that a mock object doesn't really have a working implementation? We as +users have to tell it what to do when a method is invoked. This is easy in +gMock. + +First, if the return type of a mock function is a built-in type or a pointer, +the function has a **default action** (a `void` function will just return, a +`bool` function will return `false`, and other functions will return 0). In +addition, in C++ 11 and above, a mock function whose return type is +default-constructible (i.e. has a default constructor) has a default action of +returning a default-constructed value. If you don't say anything, this behavior +will be used. + +Second, if a mock function doesn't have a default action, or the default action +doesn't suit you, you can specify the action to be taken each time the +expectation matches using a series of `WillOnce()` clauses followed by an +optional `WillRepeatedly()`. For example, + +```cpp +using ::testing::Return; +... +EXPECT_CALL(turtle, GetX()) + .WillOnce(Return(100)) + .WillOnce(Return(200)) + .WillOnce(Return(300)); +``` + +says that `turtle.GetX()` will be called *exactly three times* (gMock inferred +this from how many `WillOnce()` clauses we've written, since we didn't +explicitly write `Times()`), and will return 100, 200, and 300 respectively. + +```cpp +using ::testing::Return; +... +EXPECT_CALL(turtle, GetY()) + .WillOnce(Return(100)) + .WillOnce(Return(200)) + .WillRepeatedly(Return(300)); +``` + +says that `turtle.GetY()` will be called *at least twice* (gMock knows this as +we've written two `WillOnce()` clauses and a `WillRepeatedly()` while having no +explicit `Times()`), will return 100 and 200 respectively the first two times, +and 300 from the third time on. + +Of course, if you explicitly write a `Times()`, gMock will not try to infer the +cardinality itself. What if the number you specified is larger than there are +`WillOnce()` clauses? Well, after all `WillOnce()`s are used up, gMock will do +the *default* action for the function every time (unless, of course, you have a +`WillRepeatedly()`.). + +What can we do inside `WillOnce()` besides `Return()`? You can return a +reference using `ReturnRef(`*`variable`*`)`, or invoke a pre-defined function, +among [others](gmock_cook_book.md#using-actions). + +**Important note:** The `EXPECT_CALL()` statement evaluates the action clause +only once, even though the action may be performed many times. Therefore you +must be careful about side effects. The following may not do what you want: + +```cpp +using ::testing::Return; +... 
+int n = 100; +EXPECT_CALL(turtle, GetX()) + .Times(4) + .WillRepeatedly(Return(n++)); +``` + +Instead of returning 100, 101, 102, ..., consecutively, this mock function will +always return 100 as `n++` is only evaluated once. Similarly, `Return(new Foo)` +will create a new `Foo` object when the `EXPECT_CALL()` is executed, and will +return the same pointer every time. If you want the side effect to happen every +time, you need to define a custom action, which we'll teach in the +[cook book](gmock_cook_book.md). + +Time for another quiz! What do you think the following means? + +```cpp +using ::testing::Return; +... +EXPECT_CALL(turtle, GetY()) + .Times(4) + .WillOnce(Return(100)); +``` + +Obviously `turtle.GetY()` is expected to be called four times. But if you think +it will return 100 every time, think twice! Remember that one `WillOnce()` +clause will be consumed each time the function is invoked and the default action +will be taken afterwards. So the right answer is that `turtle.GetY()` will +return 100 the first time, but **return 0 from the second time on**, as +returning 0 is the default action for `int` functions. + +### Using Multiple Expectations {#MultiExpectations} + +So far we've only shown examples where you have a single expectation. More +realistically, you'll specify expectations on multiple mock methods which may be +from multiple mock objects. + +By default, when a mock method is invoked, gMock will search the expectations in +the **reverse order** they are defined, and stop when an active expectation that +matches the arguments is found (you can think of it as "newer rules override +older ones."). If the matching expectation cannot take any more calls, you will +get an upper-bound-violated failure. Here's an example: + +```cpp +using ::testing::_; +... +EXPECT_CALL(turtle, Forward(_)); // #1 +EXPECT_CALL(turtle, Forward(10)) // #2 + .Times(2); +``` + +If `Forward(10)` is called three times in a row, the third time it will be an +error, as the last matching expectation (#2) has been saturated. If, however, +the third `Forward(10)` call is replaced by `Forward(20)`, then it would be OK, +as now #1 will be the matching expectation. + +{: .callout .note} +**Note:** Why does gMock search for a match in the *reverse* order of the +expectations? The reason is that this allows a user to set up the default +expectations in a mock object's constructor or the test fixture's set-up phase +and then customize the mock by writing more specific expectations in the test +body. So, if you have two expectations on the same method, you want to put the +one with more specific matchers **after** the other, or the more specific rule +would be shadowed by the more general one that comes after it. + +{: .callout .tip} +**Tip:** It is very common to start with a catch-all expectation for a method +and `Times(AnyNumber())` (omitting arguments, or with `_` for all arguments, if +overloaded). This makes any calls to the method expected. This is not necessary +for methods that are not mentioned at all (these are "uninteresting"), but is +useful for methods that have some expectations, but for which other calls are +ok. See +[Understanding Uninteresting vs Unexpected Calls](gmock_cook_book.md#uninteresting-vs-unexpected). + +### Ordered vs Unordered Calls {#OrderedCalls} + +By default, an expectation can match a call even though an earlier expectation +hasn't been satisfied. In other words, the calls don't have to occur in the +order the expectations are specified. 
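+
+For example, here is a small sketch (reusing the `MockTurtle` from earlier; the
+test name is made up for illustration) in which both expectations are satisfied
+even though the actual calls happen in the opposite order from the
+`EXPECT_CALL()` statements:
+
+```cpp
+TEST(TurtleSketch, CallOrderDoesNotMatterByDefault) {
+  MockTurtle turtle;
+  EXPECT_CALL(turtle, PenDown());
+  EXPECT_CALL(turtle, Forward(100));
+
+  // Either call order satisfies the two expectations, because no ordering
+  // between them was requested.
+  turtle.Forward(100);
+  turtle.PenDown();
+}
+```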
+ +Sometimes, you may want all the expected calls to occur in a strict order. To +say this in gMock is easy: + +```cpp +using ::testing::InSequence; +... +TEST(FooTest, DrawsLineSegment) { + ... + { + InSequence seq; + + EXPECT_CALL(turtle, PenDown()); + EXPECT_CALL(turtle, Forward(100)); + EXPECT_CALL(turtle, PenUp()); + } + Foo(); +} +``` + +By creating an object of type `InSequence`, all expectations in its scope are +put into a *sequence* and have to occur *sequentially*. Since we are just +relying on the constructor and destructor of this object to do the actual work, +its name is really irrelevant. + +In this example, we test that `Foo()` calls the three expected functions in the +order as written. If a call is made out-of-order, it will be an error. + +(What if you care about the relative order of some of the calls, but not all of +them? Can you specify an arbitrary partial order? The answer is ... yes! The +details can be found [here](gmock_cook_book.md#OrderedCalls).) + +### All Expectations Are Sticky (Unless Said Otherwise) {#StickyExpectations} + +Now let's do a quick quiz to see how well you can use this mock stuff already. +How would you test that the turtle is asked to go to the origin *exactly twice* +(you want to ignore any other instructions it receives)? + +After you've come up with your answer, take a look at ours and compare notes +(solve it yourself first - don't cheat!): + +```cpp +using ::testing::_; +using ::testing::AnyNumber; +... +EXPECT_CALL(turtle, GoTo(_, _)) // #1 + .Times(AnyNumber()); +EXPECT_CALL(turtle, GoTo(0, 0)) // #2 + .Times(2); +``` + +Suppose `turtle.GoTo(0, 0)` is called three times. In the third time, gMock will +see that the arguments match expectation #2 (remember that we always pick the +last matching expectation). Now, since we said that there should be only two +such calls, gMock will report an error immediately. This is basically what we've +told you in the [Using Multiple Expectations](#MultiExpectations) section above. + +This example shows that **expectations in gMock are "sticky" by default**, in +the sense that they remain active even after we have reached their invocation +upper bounds. This is an important rule to remember, as it affects the meaning +of the spec, and is **different** to how it's done in many other mocking +frameworks (Why'd we do that? Because we think our rule makes the common cases +easier to express and understand.). + +Simple? Let's see if you've really understood it: what does the following code +say? + +```cpp +using ::testing::Return; +... +for (int i = n; i > 0; i--) { + EXPECT_CALL(turtle, GetX()) + .WillOnce(Return(10*i)); +} +``` + +If you think it says that `turtle.GetX()` will be called `n` times and will +return 10, 20, 30, ..., consecutively, think twice! The problem is that, as we +said, expectations are sticky. So, the second time `turtle.GetX()` is called, +the last (latest) `EXPECT_CALL()` statement will match, and will immediately +lead to an "upper bound violated" error - this piece of code is not very useful! + +One correct way of saying that `turtle.GetX()` will return 10, 20, 30, ..., is +to explicitly say that the expectations are *not* sticky. In other words, they +should *retire* as soon as they are saturated: + +```cpp +using ::testing::Return; +... 
+for (int i = n; i > 0; i--) { + EXPECT_CALL(turtle, GetX()) + .WillOnce(Return(10*i)) + .RetiresOnSaturation(); +} +``` + +And, there's a better way to do it: in this case, we expect the calls to occur +in a specific order, and we line up the actions to match the order. Since the +order is important here, we should make it explicit using a sequence: + +```cpp +using ::testing::InSequence; +using ::testing::Return; +... +{ + InSequence s; + + for (int i = 1; i <= n; i++) { + EXPECT_CALL(turtle, GetX()) + .WillOnce(Return(10*i)) + .RetiresOnSaturation(); + } +} +``` + +By the way, the other situation where an expectation may *not* be sticky is when +it's in a sequence - as soon as another expectation that comes after it in the +sequence has been used, it automatically retires (and will never be used to +match any call). + +### Uninteresting Calls + +A mock object may have many methods, and not all of them are that interesting. +For example, in some tests we may not care about how many times `GetX()` and +`GetY()` get called. + +In gMock, if you are not interested in a method, just don't say anything about +it. If a call to this method occurs, you'll see a warning in the test output, +but it won't be a failure. This is called "naggy" behavior; to change, see +[The Nice, the Strict, and the Naggy](gmock_cook_book.md#NiceStrictNaggy). diff --git a/3rdparty/googletest-1.13.0/docs/index.md b/3rdparty/googletest-1.13.0/docs/index.md new file mode 100644 index 0000000000000000000000000000000000000000..b162c740116394bd6871fe9e65f78cd0289b258f --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/index.md @@ -0,0 +1,22 @@ +# GoogleTest User's Guide + +## Welcome to GoogleTest! + +GoogleTest is Google's C++ testing and mocking framework. This user's guide has +the following contents: + +* [GoogleTest Primer](primer.md) - Teaches you how to write simple tests using + GoogleTest. Read this first if you are new to GoogleTest. +* [GoogleTest Advanced](advanced.md) - Read this when you've finished the + Primer and want to utilize GoogleTest to its full potential. +* [GoogleTest Samples](samples.md) - Describes some GoogleTest samples. +* [GoogleTest FAQ](faq.md) - Have a question? Want some tips? Check here + first. +* [Mocking for Dummies](gmock_for_dummies.md) - Teaches you how to create mock + objects and use them in tests. +* [Mocking Cookbook](gmock_cook_book.md) - Includes tips and approaches to + common mocking use cases. +* [Mocking Cheat Sheet](gmock_cheat_sheet.md) - A handy reference for + matchers, actions, invariants, and more. +* [Mocking FAQ](gmock_faq.md) - Contains answers to some mocking-specific + questions. diff --git a/3rdparty/googletest-1.13.0/docs/pkgconfig.md b/3rdparty/googletest-1.13.0/docs/pkgconfig.md new file mode 100644 index 0000000000000000000000000000000000000000..18a2546a3846acde26b930a5ee30a00cce96a570 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/pkgconfig.md @@ -0,0 +1,148 @@ +## Using GoogleTest from various build systems + +GoogleTest comes with pkg-config files that can be used to determine all +necessary flags for compiling and linking to GoogleTest (and GoogleMock). +Pkg-config is a standardised plain-text format containing + +* the includedir (-I) path +* necessary macro (-D) definitions +* further required flags (-pthread) +* the library (-L) path +* the library (-l) to link to + +All current build systems support pkg-config in one way or another. For all +examples here we assume you want to compile the sample +`samples/sample3_unittest.cc`. 
+ +### CMake + +Using `pkg-config` in CMake is fairly easy: + +```cmake +cmake_minimum_required(VERSION 3.0) + +cmake_policy(SET CMP0048 NEW) +project(my_gtest_pkgconfig VERSION 0.0.1 LANGUAGES CXX) + +find_package(PkgConfig) +pkg_search_module(GTEST REQUIRED gtest_main) + +add_executable(testapp samples/sample3_unittest.cc) +target_link_libraries(testapp ${GTEST_LDFLAGS}) +target_compile_options(testapp PUBLIC ${GTEST_CFLAGS}) + +include(CTest) +add_test(first_and_only_test testapp) +``` + +It is generally recommended that you use `target_compile_options` + `_CFLAGS` +over `target_include_directories` + `_INCLUDE_DIRS` as the former includes not +just -I flags (GoogleTest might require a macro indicating to internal headers +that all libraries have been compiled with threading enabled. In addition, +GoogleTest might also require `-pthread` in the compiling step, and as such +splitting the pkg-config `Cflags` variable into include dirs and macros for +`target_compile_definitions()` might still miss this). The same recommendation +goes for using `_LDFLAGS` over the more commonplace `_LIBRARIES`, which happens +to discard `-L` flags and `-pthread`. + +### Help! pkg-config can't find GoogleTest! + +Let's say you have a `CMakeLists.txt` along the lines of the one in this +tutorial and you try to run `cmake`. It is very possible that you get a failure +along the lines of: + +``` +-- Checking for one of the modules 'gtest_main' +CMake Error at /usr/share/cmake/Modules/FindPkgConfig.cmake:640 (message): + None of the required 'gtest_main' found +``` + +These failures are common if you installed GoogleTest yourself and have not +sourced it from a distro or other package manager. If so, you need to tell +pkg-config where it can find the `.pc` files containing the information. Say you +installed GoogleTest to `/usr/local`, then it might be that the `.pc` files are +installed under `/usr/local/lib64/pkgconfig`. If you set + +``` +export PKG_CONFIG_PATH=/usr/local/lib64/pkgconfig +``` + +pkg-config will also try to look in `PKG_CONFIG_PATH` to find `gtest_main.pc`. + +### Using pkg-config in a cross-compilation setting + +Pkg-config can be used in a cross-compilation setting too. To do this, let's +assume the final prefix of the cross-compiled installation will be `/usr`, and +your sysroot is `/home/MYUSER/sysroot`. Configure and install GTest using + +``` +mkdir build && cmake -DCMAKE_INSTALL_PREFIX=/usr .. +``` + +Install into the sysroot using `DESTDIR`: + +``` +make -j install DESTDIR=/home/MYUSER/sysroot +``` + +Before we continue, it is recommended to **always** define the following two +variables for pkg-config in a cross-compilation setting: + +``` +export PKG_CONFIG_ALLOW_SYSTEM_CFLAGS=yes +export PKG_CONFIG_ALLOW_SYSTEM_LIBS=yes +``` + +otherwise `pkg-config` will filter `-I` and `-L` flags against standard prefixes +such as `/usr` (see https://bugs.freedesktop.org/show_bug.cgi?id=28264#c3 for +reasons why this stripping needs to occur usually). + +If you look at the generated pkg-config file, it will look something like + +``` +libdir=/usr/lib64 +includedir=/usr/include + +Name: gtest +Description: GoogleTest (without main() function) +Version: 1.11.0 +URL: https://github.com/google/googletest +Libs: -L${libdir} -lgtest -lpthread +Cflags: -I${includedir} -DGTEST_HAS_PTHREAD=1 -lpthread +``` + +Notice that the sysroot is not included in `libdir` and `includedir`! 
If you try +to run `pkg-config` with the correct +`PKG_CONFIG_LIBDIR=/home/MYUSER/sysroot/usr/lib64/pkgconfig` against this `.pc` +file, you will get + +``` +$ pkg-config --cflags gtest +-DGTEST_HAS_PTHREAD=1 -lpthread -I/usr/include +$ pkg-config --libs gtest +-L/usr/lib64 -lgtest -lpthread +``` + +which is obviously wrong and points to the `CBUILD` and not `CHOST` root. In +order to use this in a cross-compilation setting, we need to tell pkg-config to +inject the actual sysroot into `-I` and `-L` variables. Let us now tell +pkg-config about the actual sysroot + +``` +export PKG_CONFIG_DIR= +export PKG_CONFIG_SYSROOT_DIR=/home/MYUSER/sysroot +export PKG_CONFIG_LIBDIR=${PKG_CONFIG_SYSROOT_DIR}/usr/lib64/pkgconfig +``` + +and running `pkg-config` again we get + +``` +$ pkg-config --cflags gtest +-DGTEST_HAS_PTHREAD=1 -lpthread -I/home/MYUSER/sysroot/usr/include +$ pkg-config --libs gtest +-L/home/MYUSER/sysroot/usr/lib64 -lgtest -lpthread +``` + +which contains the correct sysroot now. For a more comprehensive guide to also +including `${CHOST}` in build system calls, see the excellent tutorial by Diego +Elio Pettenò: diff --git a/3rdparty/googletest-1.13.0/docs/platforms.md b/3rdparty/googletest-1.13.0/docs/platforms.md new file mode 100644 index 0000000000000000000000000000000000000000..eba6ef805661f33dff7588039396678a19a108a9 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/platforms.md @@ -0,0 +1,35 @@ +# Supported Platforms + +GoogleTest requires a codebase and compiler compliant with the C++11 standard or +newer. + +The GoogleTest code is officially supported on the following platforms. +Operating systems or tools not listed below are community-supported. For +community-supported platforms, patches that do not complicate the code may be +considered. + +If you notice any problems on your platform, please file an issue on the +[GoogleTest GitHub Issue Tracker](https://github.com/google/googletest/issues). +Pull requests containing fixes are welcome! + +### Operating systems + +* Linux +* macOS +* Windows + +### Compilers + +* gcc 5.0+ +* clang 5.0+ +* MSVC 2015+ + +**macOS users:** Xcode 9.3+ provides clang 5.0+. + +### Build systems + +* [Bazel](https://bazel.build/) +* [CMake](https://cmake.org/) + +Bazel is the build system used by the team internally and in tests. CMake is +supported on a best-effort basis and by the community. diff --git a/3rdparty/googletest-1.13.0/docs/primer.md b/3rdparty/googletest-1.13.0/docs/primer.md new file mode 100644 index 0000000000000000000000000000000000000000..2ffbf53bc8dff6337e8f4c4d33d6f7b4df767bbe --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/primer.md @@ -0,0 +1,483 @@ +# Googletest Primer + +## Introduction: Why googletest? + +*googletest* helps you write better C++ tests. + +googletest is a testing framework developed by the Testing Technology team with +Google's specific requirements and constraints in mind. Whether you work on +Linux, Windows, or a Mac, if you write C++ code, googletest can help you. And it +supports *any* kind of tests, not just unit tests. + +So what makes a good test, and how does googletest fit in? We believe: + +1. Tests should be *independent* and *repeatable*. It's a pain to debug a test + that succeeds or fails as a result of other tests. googletest isolates the + tests by running each of them on a different object. When a test fails, + googletest allows you to run it in isolation for quick debugging. +2. Tests should be well *organized* and reflect the structure of the tested + code. 
googletest groups related tests into test suites that can share data + and subroutines. This common pattern is easy to recognize and makes tests + easy to maintain. Such consistency is especially helpful when people switch + projects and start to work on a new code base. +3. Tests should be *portable* and *reusable*. Google has a lot of code that is + platform-neutral; its tests should also be platform-neutral. googletest + works on different OSes, with different compilers, with or without + exceptions, so googletest tests can work with a variety of configurations. +4. When tests fail, they should provide as much *information* about the problem + as possible. googletest doesn't stop at the first test failure. Instead, it + only stops the current test and continues with the next. You can also set up + tests that report non-fatal failures after which the current test continues. + Thus, you can detect and fix multiple bugs in a single run-edit-compile + cycle. +5. The testing framework should liberate test writers from housekeeping chores + and let them focus on the test *content*. googletest automatically keeps + track of all tests defined, and doesn't require the user to enumerate them + in order to run them. +6. Tests should be *fast*. With googletest, you can reuse shared resources + across tests and pay for the set-up/tear-down only once, without making + tests depend on each other. + +Since googletest is based on the popular xUnit architecture, you'll feel right +at home if you've used JUnit or PyUnit before. If not, it will take you about 10 +minutes to learn the basics and get started. So let's go! + +## Beware of the nomenclature + +{: .callout .note} +_Note:_ There might be some confusion arising from different definitions of the +terms _Test_, _Test Case_ and _Test Suite_, so beware of misunderstanding these. + +Historically, googletest started to use the term _Test Case_ for grouping +related tests, whereas current publications, including International Software +Testing Qualifications Board ([ISTQB](http://www.istqb.org/)) materials and +various textbooks on software quality, use the term +_[Test Suite][istqb test suite]_ for this. + +The related term _Test_, as it is used in googletest, corresponds to the term +_[Test Case][istqb test case]_ of ISTQB and others. + +The term _Test_ is commonly of broad enough sense, including ISTQB's definition +of _Test Case_, so it's not much of a problem here. But the term _Test Case_ as +was used in Google Test is of contradictory sense and thus confusing. + +googletest recently started replacing the term _Test Case_ with _Test Suite_. +The preferred API is *TestSuite*. The older TestCase API is being slowly +deprecated and refactored away. + +So please be aware of the different definitions of the terms: + + +Meaning | googletest Term | [ISTQB](http://www.istqb.org/) Term +:----------------------------------------------------------------------------------- | :---------------------- | :---------------------------------- +Exercise a particular program path with specific input values and verify the results | [TEST()](#simple-tests) | [Test Case][istqb test case] + + +[istqb test case]: http://glossary.istqb.org/en/search/test%20case +[istqb test suite]: http://glossary.istqb.org/en/search/test%20suite + +## Basic Concepts + +When using googletest, you start by writing *assertions*, which are statements +that check whether a condition is true. An assertion's result can be *success*, +*nonfatal failure*, or *fatal failure*. 
If a fatal failure occurs, it aborts the +current function; otherwise the program continues normally. + +*Tests* use assertions to verify the tested code's behavior. If a test crashes +or has a failed assertion, then it *fails*; otherwise it *succeeds*. + +A *test suite* contains one or many tests. You should group your tests into test +suites that reflect the structure of the tested code. When multiple tests in a +test suite need to share common objects and subroutines, you can put them into a +*test fixture* class. + +A *test program* can contain multiple test suites. + +We'll now explain how to write a test program, starting at the individual +assertion level and building up to tests and test suites. + +## Assertions + +googletest assertions are macros that resemble function calls. You test a class +or function by making assertions about its behavior. When an assertion fails, +googletest prints the assertion's source file and line number location, along +with a failure message. You may also supply a custom failure message which will +be appended to googletest's message. + +The assertions come in pairs that test the same thing but have different effects +on the current function. `ASSERT_*` versions generate fatal failures when they +fail, and **abort the current function**. `EXPECT_*` versions generate nonfatal +failures, which don't abort the current function. Usually `EXPECT_*` are +preferred, as they allow more than one failure to be reported in a test. +However, you should use `ASSERT_*` if it doesn't make sense to continue when the +assertion in question fails. + +Since a failed `ASSERT_*` returns from the current function immediately, +possibly skipping clean-up code that comes after it, it may cause a space leak. +Depending on the nature of the leak, it may or may not be worth fixing - so keep +this in mind if you get a heap checker error in addition to assertion errors. + +To provide a custom failure message, simply stream it into the macro using the +`<<` operator or a sequence of such operators. See the following example, using +the [`ASSERT_EQ` and `EXPECT_EQ`](reference/assertions.md#EXPECT_EQ) macros to +verify value equality: + +```c++ +ASSERT_EQ(x.size(), y.size()) << "Vectors x and y are of unequal length"; + +for (int i = 0; i < x.size(); ++i) { + EXPECT_EQ(x[i], y[i]) << "Vectors x and y differ at index " << i; +} +``` + +Anything that can be streamed to an `ostream` can be streamed to an assertion +macro--in particular, C strings and `string` objects. If a wide string +(`wchar_t*`, `TCHAR*` in `UNICODE` mode on Windows, or `std::wstring`) is +streamed to an assertion, it will be translated to UTF-8 when printed. + +GoogleTest provides a collection of assertions for verifying the behavior of +your code in various ways. You can check Boolean conditions, compare values +based on relational operators, verify string values, floating-point values, and +much more. There are even assertions that enable you to verify more complex +states by providing custom predicates. For the complete list of assertions +provided by GoogleTest, see the [Assertions Reference](reference/assertions.md). + +## Simple Tests + +To create a test: + +1. Use the `TEST()` macro to define and name a test function. These are + ordinary C++ functions that don't return a value. +2. In this function, along with any valid C++ statements you want to include, + use the various googletest assertions to check values. +3. 
The test's result is determined by the assertions; if any assertion in the + test fails (either fatally or non-fatally), or if the test crashes, the + entire test fails. Otherwise, it succeeds. + +```c++ +TEST(TestSuiteName, TestName) { + ... test body ... +} +``` + +`TEST()` arguments go from general to specific. The *first* argument is the name +of the test suite, and the *second* argument is the test's name within the test +suite. Both names must be valid C++ identifiers, and they should not contain any +underscores (`_`). A test's *full name* consists of its containing test suite +and its individual name. Tests from different test suites can have the same +individual name. + +For example, let's take a simple integer function: + +```c++ +int Factorial(int n); // Returns the factorial of n +``` + +A test suite for this function might look like: + +```c++ +// Tests factorial of 0. +TEST(FactorialTest, HandlesZeroInput) { + EXPECT_EQ(Factorial(0), 1); +} + +// Tests factorial of positive numbers. +TEST(FactorialTest, HandlesPositiveInput) { + EXPECT_EQ(Factorial(1), 1); + EXPECT_EQ(Factorial(2), 2); + EXPECT_EQ(Factorial(3), 6); + EXPECT_EQ(Factorial(8), 40320); +} +``` + +googletest groups the test results by test suites, so logically related tests +should be in the same test suite; in other words, the first argument to their +`TEST()` should be the same. In the above example, we have two tests, +`HandlesZeroInput` and `HandlesPositiveInput`, that belong to the same test +suite `FactorialTest`. + +When naming your test suites and tests, you should follow the same convention as +for +[naming functions and classes](https://google.github.io/styleguide/cppguide.html#Function_Names). + +**Availability**: Linux, Windows, Mac. + +## Test Fixtures: Using the Same Data Configuration for Multiple Tests {#same-data-multiple-tests} + +If you find yourself writing two or more tests that operate on similar data, you +can use a *test fixture*. This allows you to reuse the same configuration of +objects for several different tests. + +To create a fixture: + +1. Derive a class from `::testing::Test` . Start its body with `protected:`, as + we'll want to access fixture members from sub-classes. +2. Inside the class, declare any objects you plan to use. +3. If necessary, write a default constructor or `SetUp()` function to prepare + the objects for each test. A common mistake is to spell `SetUp()` as + **`Setup()`** with a small `u` - Use `override` in C++11 to make sure you + spelled it correctly. +4. If necessary, write a destructor or `TearDown()` function to release any + resources you allocated in `SetUp()` . To learn when you should use the + constructor/destructor and when you should use `SetUp()/TearDown()`, read + the [FAQ](faq.md#CtorVsSetUp). +5. If needed, define subroutines for your tests to share. + +When using a fixture, use `TEST_F()` instead of `TEST()` as it allows you to +access objects and subroutines in the test fixture: + +```c++ +TEST_F(TestFixtureName, TestName) { + ... test body ... +} +``` + +Like `TEST()`, the first argument is the test suite name, but for `TEST_F()` +this must be the name of the test fixture class. You've probably guessed: `_F` +is for fixture. + +Unfortunately, the C++ macro system does not allow us to create a single macro +that can handle both types of tests. Using the wrong macro causes a compiler +error. 
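+
+To make the fixture checklist above concrete, here is a minimal sketch (the
+`WidgetTest` fixture and its member are invented for illustration) of a fixture
+and a `TEST_F()` that uses it; note the `override` on `SetUp()`, which turns the
+`Setup()` misspelling mentioned above into a compile-time error instead of a
+silently unused method:
+
+```c++
+#include <vector>
+
+#include "gtest/gtest.h"
+
+class WidgetTest : public ::testing::Test {
+ protected:
+  // `override` makes the compiler reject a misspelled `Setup()`.
+  void SetUp() override { data_.push_back(42); }
+
+  std::vector<int> data_;
+};
+
+TEST_F(WidgetTest, StartsWithOneElement) {
+  EXPECT_EQ(data_.size(), 1);
+}
+```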
+ +Also, you must first define a test fixture class before using it in a +`TEST_F()`, or you'll get the compiler error "`virtual outside class +declaration`". + +For each test defined with `TEST_F()`, googletest will create a *fresh* test +fixture at runtime, immediately initialize it via `SetUp()`, run the test, clean +up by calling `TearDown()`, and then delete the test fixture. Note that +different tests in the same test suite have different test fixture objects, and +googletest always deletes a test fixture before it creates the next one. +googletest does **not** reuse the same test fixture for multiple tests. Any +changes one test makes to the fixture do not affect other tests. + +As an example, let's write tests for a FIFO queue class named `Queue`, which has +the following interface: + +```c++ +template // E is the element type. +class Queue { + public: + Queue(); + void Enqueue(const E& element); + E* Dequeue(); // Returns NULL if the queue is empty. + size_t size() const; + ... +}; +``` + +First, define a fixture class. By convention, you should give it the name +`FooTest` where `Foo` is the class being tested. + +```c++ +class QueueTest : public ::testing::Test { + protected: + void SetUp() override { + // q0_ remains empty + q1_.Enqueue(1); + q2_.Enqueue(2); + q2_.Enqueue(3); + } + + // void TearDown() override {} + + Queue q0_; + Queue q1_; + Queue q2_; +}; +``` + +In this case, `TearDown()` is not needed since we don't have to clean up after +each test, other than what's already done by the destructor. + +Now we'll write tests using `TEST_F()` and this fixture. + +```c++ +TEST_F(QueueTest, IsEmptyInitially) { + EXPECT_EQ(q0_.size(), 0); +} + +TEST_F(QueueTest, DequeueWorks) { + int* n = q0_.Dequeue(); + EXPECT_EQ(n, nullptr); + + n = q1_.Dequeue(); + ASSERT_NE(n, nullptr); + EXPECT_EQ(*n, 1); + EXPECT_EQ(q1_.size(), 0); + delete n; + + n = q2_.Dequeue(); + ASSERT_NE(n, nullptr); + EXPECT_EQ(*n, 2); + EXPECT_EQ(q2_.size(), 1); + delete n; +} +``` + +The above uses both `ASSERT_*` and `EXPECT_*` assertions. The rule of thumb is +to use `EXPECT_*` when you want the test to continue to reveal more errors after +the assertion failure, and use `ASSERT_*` when continuing after failure doesn't +make sense. For example, the second assertion in the `Dequeue` test is +`ASSERT_NE(n, nullptr)`, as we need to dereference the pointer `n` later, which +would lead to a segfault when `n` is `NULL`. + +When these tests run, the following happens: + +1. googletest constructs a `QueueTest` object (let's call it `t1`). +2. `t1.SetUp()` initializes `t1`. +3. The first test (`IsEmptyInitially`) runs on `t1`. +4. `t1.TearDown()` cleans up after the test finishes. +5. `t1` is destructed. +6. The above steps are repeated on another `QueueTest` object, this time + running the `DequeueWorks` test. + +**Availability**: Linux, Windows, Mac. + +## Invoking the Tests + +`TEST()` and `TEST_F()` implicitly register their tests with googletest. So, +unlike with many other C++ testing frameworks, you don't have to re-list all +your defined tests in order to run them. + +After defining your tests, you can run them with `RUN_ALL_TESTS()`, which +returns `0` if all the tests are successful, or `1` otherwise. Note that +`RUN_ALL_TESTS()` runs *all tests* in your link unit--they can be from different +test suites, or even different source files. + +When invoked, the `RUN_ALL_TESTS()` macro: + +* Saves the state of all googletest flags. + +* Creates a test fixture object for the first test. 
+ +* Initializes it via `SetUp()`. + +* Runs the test on the fixture object. + +* Cleans up the fixture via `TearDown()`. + +* Deletes the fixture. + +* Restores the state of all googletest flags. + +* Repeats the above steps for the next test, until all tests have run. + +If a fatal failure happens the subsequent steps will be skipped. + +{: .callout .important} +> IMPORTANT: You must **not** ignore the return value of `RUN_ALL_TESTS()`, or +> you will get a compiler error. The rationale for this design is that the +> automated testing service determines whether a test has passed based on its +> exit code, not on its stdout/stderr output; thus your `main()` function must +> return the value of `RUN_ALL_TESTS()`. +> +> Also, you should call `RUN_ALL_TESTS()` only **once**. Calling it more than +> once conflicts with some advanced googletest features (e.g., thread-safe +> [death tests](advanced.md#death-tests)) and thus is not supported. + +**Availability**: Linux, Windows, Mac. + +## Writing the main() Function + +Most users should _not_ need to write their own `main` function and instead link +with `gtest_main` (as opposed to with `gtest`), which defines a suitable entry +point. See the end of this section for details. The remainder of this section +should only apply when you need to do something custom before the tests run that +cannot be expressed within the framework of fixtures and test suites. + +If you write your own `main` function, it should return the value of +`RUN_ALL_TESTS()`. + +You can start from this boilerplate: + +```c++ +#include "this/package/foo.h" + +#include "gtest/gtest.h" + +namespace my { +namespace project { +namespace { + +// The fixture for testing class Foo. +class FooTest : public ::testing::Test { + protected: + // You can remove any or all of the following functions if their bodies would + // be empty. + + FooTest() { + // You can do set-up work for each test here. + } + + ~FooTest() override { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + void SetUp() override { + // Code here will be called immediately after the constructor (right + // before each test). + } + + void TearDown() override { + // Code here will be called immediately after each test (right + // before the destructor). + } + + // Class members declared here can be used by all tests in the test suite + // for Foo. +}; + +// Tests that the Foo::Bar() method does Abc. +TEST_F(FooTest, MethodBarDoesAbc) { + const std::string input_filepath = "this/package/testdata/myinputfile.dat"; + const std::string output_filepath = "this/package/testdata/myoutputfile.dat"; + Foo f; + EXPECT_EQ(f.Bar(input_filepath, output_filepath), 0); +} + +// Tests that Foo does Xyz. +TEST_F(FooTest, DoesXyz) { + // Exercises the Xyz feature of Foo. +} + +} // namespace +} // namespace project +} // namespace my + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} +``` + +The `::testing::InitGoogleTest()` function parses the command line for +googletest flags, and removes all recognized flags. This allows the user to +control a test program's behavior via various flags, which we'll cover in the +[AdvancedGuide](advanced.md). You **must** call this function before calling +`RUN_ALL_TESTS()`, or the flags won't be properly initialized. 
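+
+Because `InitGoogleTest()` removes the flags it recognizes from `argc` and
+`argv`, anything left over is available for your own argument handling. A
+minimal sketch (the pass-through printing is invented for illustration, not
+part of the boilerplate above):
+
+```c++
+#include <cstdio>
+
+#include "gtest/gtest.h"
+
+int main(int argc, char **argv) {
+  ::testing::InitGoogleTest(&argc, argv);  // Consumes the --gtest_* flags.
+  // Whatever remains in argv after the call is yours to interpret.
+  for (int i = 1; i < argc; ++i) {
+    std::printf("Non-googletest argument: %s\n", argv[i]);
+  }
+  return RUN_ALL_TESTS();
+}
+```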
+ +On Windows, `InitGoogleTest()` also works with wide strings, so it can be used +in programs compiled in `UNICODE` mode as well. + +But maybe you think that writing all those `main` functions is too much work? We +agree with you completely, and that's why Google Test provides a basic +implementation of main(). If it fits your needs, then just link your test with +the `gtest_main` library and you are good to go. + +{: .callout .note} +NOTE: `ParseGUnitFlags()` is deprecated in favor of `InitGoogleTest()`. + +## Known Limitations + +* Google Test is designed to be thread-safe. The implementation is thread-safe + on systems where the `pthreads` library is available. It is currently + _unsafe_ to use Google Test assertions from two threads concurrently on + other systems (e.g. Windows). In most tests this is not an issue as usually + the assertions are done in the main thread. If you want to help, you can + volunteer to implement the necessary synchronization primitives in + `gtest-port.h` for your platform. diff --git a/3rdparty/googletest-1.13.0/docs/quickstart-bazel.md b/3rdparty/googletest-1.13.0/docs/quickstart-bazel.md new file mode 100644 index 0000000000000000000000000000000000000000..15c27a22ed9c63eeb234e35db8f02bb63ba8c9b8 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/quickstart-bazel.md @@ -0,0 +1,146 @@ +# Quickstart: Building with Bazel + +This tutorial aims to get you up and running with GoogleTest using the Bazel +build system. If you're using GoogleTest for the first time or need a refresher, +we recommend this tutorial as a starting point. + +## Prerequisites + +To complete this tutorial, you'll need: + +* A compatible operating system (e.g. Linux, macOS, Windows). +* A compatible C++ compiler that supports at least C++14. +* [Bazel](https://bazel.build/), the preferred build system used by the + GoogleTest team. + +See [Supported Platforms](platforms.md) for more information about platforms +compatible with GoogleTest. + +If you don't already have Bazel installed, see the +[Bazel installation guide](https://bazel.build/install). + +{: .callout .note} Note: The terminal commands in this tutorial show a Unix +shell prompt, but the commands work on the Windows command line as well. + +## Set up a Bazel workspace + +A +[Bazel workspace](https://docs.bazel.build/versions/main/build-ref.html#workspace) +is a directory on your filesystem that you use to manage source files for the +software you want to build. Each workspace directory has a text file named +`WORKSPACE` which may be empty, or may contain references to external +dependencies required to build the outputs. + +First, create a directory for your workspace: + +``` +$ mkdir my_workspace && cd my_workspace +``` + +Next, you’ll create the `WORKSPACE` file to specify dependencies. A common and +recommended way to depend on GoogleTest is to use a +[Bazel external dependency](https://docs.bazel.build/versions/main/external.html) +via the +[`http_archive` rule](https://docs.bazel.build/versions/main/repo/http.html#http_archive). 
+
+To do this, in the root directory of your workspace (`my_workspace/`), create a
+file named `WORKSPACE` with the following contents:
+
+```
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
+http_archive(
+  name = "com_google_googletest",
+  urls = ["https://github.com/google/googletest/archive/5ab508a01f9eb089207ee87fd547d290da39d015.zip"],
+  strip_prefix = "googletest-5ab508a01f9eb089207ee87fd547d290da39d015",
+)
+```
+
+The above configuration declares a dependency on GoogleTest which is downloaded
+as a ZIP archive from GitHub. In the above example,
+`5ab508a01f9eb089207ee87fd547d290da39d015` is the Git commit hash of the
+GoogleTest version to use; we recommend updating the hash often to point to the
+latest version. Use a recent hash on the `main` branch.
+
+Now you're ready to build C++ code that uses GoogleTest.
+
+## Create and run a binary
+
+With your Bazel workspace set up, you can now use GoogleTest code within your
+own project.
+
+As an example, create a file named `hello_test.cc` in your `my_workspace`
+directory with the following contents:
+
+```cpp
+#include <gtest/gtest.h>
+
+// Demonstrate some basic assertions.
+TEST(HelloTest, BasicAssertions) {
+  // Expect two strings not to be equal.
+  EXPECT_STRNE("hello", "world");
+  // Expect equality.
+  EXPECT_EQ(7 * 6, 42);
+}
+```
+
+GoogleTest provides [assertions](primer.md#assertions) that you use to test the
+behavior of your code. The above sample includes the main GoogleTest header file
+and demonstrates some basic assertions.
+
+To build the code, create a file named `BUILD` in the same directory with the
+following contents:
+
+```
+cc_test(
+  name = "hello_test",
+  size = "small",
+  srcs = ["hello_test.cc"],
+  deps = ["@com_google_googletest//:gtest_main"],
+)
+```
+
+This `cc_test` rule declares the C++ test binary you want to build, and links to
+GoogleTest (`//:gtest_main`) using the prefix you specified in the `WORKSPACE`
+file (`@com_google_googletest`). For more information about Bazel `BUILD` files,
+see the
+[Bazel C++ Tutorial](https://docs.bazel.build/versions/main/tutorial/cpp.html).
+
+Now you can build and run your test:
+
+```
+my_workspace$ bazel test --test_output=all //:hello_test
+INFO: Analyzed target //:hello_test (26 packages loaded, 362 targets configured).
+INFO: Found 1 test target...
+INFO: From Testing //:hello_test:
+==================== Test output for //:hello_test:
+Running main() from gmock_main.cc
+[==========] Running 1 test from 1 test suite.
+[----------] Global test environment set-up.
+[----------] 1 test from HelloTest
+[ RUN      ] HelloTest.BasicAssertions
+[       OK ] HelloTest.BasicAssertions (0 ms)
+[----------] 1 test from HelloTest (0 ms total)
+
+[----------] Global test environment tear-down
+[==========] 1 test from 1 test suite ran. (0 ms total)
+[  PASSED  ] 1 test.
+================================================================================
+Target //:hello_test up-to-date:
+  bazel-bin/hello_test
+INFO: Elapsed time: 4.190s, Critical Path: 3.05s
+INFO: 27 processes: 8 internal, 19 linux-sandbox.
+INFO: Build completed successfully, 27 total actions
+//:hello_test                                                     PASSED in 0.1s
+
+INFO: Build completed successfully, 27 total actions
+```
+
+ +Congratulations! You've successfully built and run a test binary using +GoogleTest. + +## Next steps + +* [Check out the Primer](primer.md) to start learning how to write simple + tests. +* [See the code samples](samples.md) for more examples showing how to use a + variety of GoogleTest features. diff --git a/3rdparty/googletest-1.13.0/docs/quickstart-cmake.md b/3rdparty/googletest-1.13.0/docs/quickstart-cmake.md new file mode 100644 index 0000000000000000000000000000000000000000..5abe50441294bd3183c3d1d9f1934f7fea03f88f --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/quickstart-cmake.md @@ -0,0 +1,156 @@ +# Quickstart: Building with CMake + +This tutorial aims to get you up and running with GoogleTest using CMake. If +you're using GoogleTest for the first time or need a refresher, we recommend +this tutorial as a starting point. If your project uses Bazel, see the +[Quickstart for Bazel](quickstart-bazel.md) instead. + +## Prerequisites + +To complete this tutorial, you'll need: + +* A compatible operating system (e.g. Linux, macOS, Windows). +* A compatible C++ compiler that supports at least C++14. +* [CMake](https://cmake.org/) and a compatible build tool for building the + project. + * Compatible build tools include + [Make](https://www.gnu.org/software/make/), + [Ninja](https://ninja-build.org/), and others - see + [CMake Generators](https://cmake.org/cmake/help/latest/manual/cmake-generators.7.html) + for more information. + +See [Supported Platforms](platforms.md) for more information about platforms +compatible with GoogleTest. + +If you don't already have CMake installed, see the +[CMake installation guide](https://cmake.org/install). + +{: .callout .note} +Note: The terminal commands in this tutorial show a Unix shell prompt, but the +commands work on the Windows command line as well. + +## Set up a project + +CMake uses a file named `CMakeLists.txt` to configure the build system for a +project. You'll use this file to set up your project and declare a dependency on +GoogleTest. + +First, create a directory for your project: + +``` +$ mkdir my_project && cd my_project +``` + +Next, you'll create the `CMakeLists.txt` file and declare a dependency on +GoogleTest. There are many ways to express dependencies in the CMake ecosystem; +in this quickstart, you'll use the +[`FetchContent` CMake module](https://cmake.org/cmake/help/latest/module/FetchContent.html). +To do this, in your project directory (`my_project`), create a file named +`CMakeLists.txt` with the following contents: + +```cmake +cmake_minimum_required(VERSION 3.14) +project(my_project) + +# GoogleTest requires at least C++14 +set(CMAKE_CXX_STANDARD 14) + +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip +) +# For Windows: Prevent overriding the parent project's compiler/linker settings +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) +``` + +The above configuration declares a dependency on GoogleTest which is downloaded +from GitHub. In the above example, `03597a01ee50ed33e9dfd640b249b4be3799d395` is +the Git commit hash of the GoogleTest version to use; we recommend updating the +hash often to point to the latest version. + +For more information about how to create `CMakeLists.txt` files, see the +[CMake Tutorial](https://cmake.org/cmake/help/latest/guide/tutorial/index.html). 
+
+## Create and run a binary
+
+With GoogleTest declared as a dependency, you can use GoogleTest code within
+your own project.
+
+As an example, create a file named `hello_test.cc` in your `my_project`
+directory with the following contents:
+
+```cpp
+#include <gtest/gtest.h>
+
+// Demonstrate some basic assertions.
+TEST(HelloTest, BasicAssertions) {
+  // Expect two strings not to be equal.
+  EXPECT_STRNE("hello", "world");
+  // Expect equality.
+  EXPECT_EQ(7 * 6, 42);
+}
+```
+
+GoogleTest provides [assertions](primer.md#assertions) that you use to test the
+behavior of your code. The above sample includes the main GoogleTest header file
+and demonstrates some basic assertions.
+
+To build the code, add the following to the end of your `CMakeLists.txt` file:
+
+```cmake
+enable_testing()
+
+add_executable(
+  hello_test
+  hello_test.cc
+)
+target_link_libraries(
+  hello_test
+  GTest::gtest_main
+)
+
+include(GoogleTest)
+gtest_discover_tests(hello_test)
+```
+
+The above configuration enables testing in CMake, declares the C++ test binary
+you want to build (`hello_test`), and links it to GoogleTest (`gtest_main`). The
+last two lines enable CMake's test runner to discover the tests included in the
+binary, using the
+[`GoogleTest` CMake module](https://cmake.org/cmake/help/git-stage/module/GoogleTest.html).
+
+Now you can build and run your test:
+
+```
+my_project$ cmake -S . -B build
+-- The C compiler identification is GNU 10.2.1
+-- The CXX compiler identification is GNU 10.2.1
+...
+-- Build files have been written to: .../my_project/build
+
+my_project$ cmake --build build
+Scanning dependencies of target gtest
+...
+[100%] Built target gmock_main
+
+my_project$ cd build && ctest
+Test project .../my_project/build
+    Start 1: HelloTest.BasicAssertions
+1/1 Test #1: HelloTest.BasicAssertions ........   Passed    0.00 sec
+
+100% tests passed, 0 tests failed out of 1
+
+Total Test time (real) =   0.01 sec
+```
+
+ +Congratulations! You've successfully built and run a test binary using +GoogleTest. + +## Next steps + +* [Check out the Primer](primer.md) to start learning how to write simple + tests. +* [See the code samples](samples.md) for more examples showing how to use a + variety of GoogleTest features. diff --git a/3rdparty/googletest-1.13.0/docs/reference/actions.md b/3rdparty/googletest-1.13.0/docs/reference/actions.md new file mode 100644 index 0000000000000000000000000000000000000000..ab81a129eff692d513b27c155abed96dd30f8db6 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/reference/actions.md @@ -0,0 +1,115 @@ +# Actions Reference + +[**Actions**](../gmock_for_dummies.md#actions-what-should-it-do) specify what a +mock function should do when invoked. This page lists the built-in actions +provided by GoogleTest. All actions are defined in the `::testing` namespace. + +## Returning a Value + +| Action | Description | +| :-------------------------------- | :-------------------------------------------- | +| `Return()` | Return from a `void` mock function. | +| `Return(value)` | Return `value`. If the type of `value` is different to the mock function's return type, `value` is converted to the latter type at the time the expectation is set, not when the action is executed. | +| `ReturnArg()` | Return the `N`-th (0-based) argument. | +| `ReturnNew(a1, ..., ak)` | Return `new T(a1, ..., ak)`; a different object is created each time. | +| `ReturnNull()` | Return a null pointer. | +| `ReturnPointee(ptr)` | Return the value pointed to by `ptr`. | +| `ReturnRef(variable)` | Return a reference to `variable`. | +| `ReturnRefOfCopy(value)` | Return a reference to a copy of `value`; the copy lives as long as the action. | +| `ReturnRoundRobin({a1, ..., ak})` | Each call will return the next `ai` in the list, starting at the beginning when the end of the list is reached. | + +## Side Effects + +| Action | Description | +| :--------------------------------- | :-------------------------------------- | +| `Assign(&variable, value)` | Assign `value` to variable. | +| `DeleteArg()` | Delete the `N`-th (0-based) argument, which must be a pointer. | +| `SaveArg(pointer)` | Save the `N`-th (0-based) argument to `*pointer`. | +| `SaveArgPointee(pointer)` | Save the value pointed to by the `N`-th (0-based) argument to `*pointer`. | +| `SetArgReferee(value)` | Assign `value` to the variable referenced by the `N`-th (0-based) argument. | +| `SetArgPointee(value)` | Assign `value` to the variable pointed by the `N`-th (0-based) argument. | +| `SetArgumentPointee(value)` | Same as `SetArgPointee(value)`. Deprecated. Will be removed in v1.7.0. | +| `SetArrayArgument(first, last)` | Copies the elements in source range [`first`, `last`) to the array pointed to by the `N`-th (0-based) argument, which can be either a pointer or an iterator. The action does not take ownership of the elements in the source range. | +| `SetErrnoAndReturn(error, value)` | Set `errno` to `error` and return `value`. | +| `Throw(exception)` | Throws the given exception, which can be any copyable value. Available since v1.1.0. | + +## Using a Function, Functor, or Lambda as an Action + +In the following, by "callable" we mean a free function, `std::function`, +functor, or lambda. + +| Action | Description | +| :---------------------------------- | :------------------------------------- | +| `f` | Invoke `f` with the arguments passed to the mock function, where `f` is a callable. 
| +| `Invoke(f)` | Invoke `f` with the arguments passed to the mock function, where `f` can be a global/static function or a functor. | +| `Invoke(object_pointer, &class::method)` | Invoke the method on the object with the arguments passed to the mock function. | +| `InvokeWithoutArgs(f)` | Invoke `f`, which can be a global/static function or a functor. `f` must take no arguments. | +| `InvokeWithoutArgs(object_pointer, &class::method)` | Invoke the method on the object, which takes no arguments. | +| `InvokeArgument(arg1, arg2, ..., argk)` | Invoke the mock function's `N`-th (0-based) argument, which must be a function or a functor, with the `k` arguments. | + +The return value of the invoked function is used as the return value of the +action. + +When defining a callable to be used with `Invoke*()`, you can declare any unused +parameters as `Unused`: + +```cpp +using ::testing::Invoke; +double Distance(Unused, double x, double y) { return sqrt(x*x + y*y); } +... +EXPECT_CALL(mock, Foo("Hi", _, _)).WillOnce(Invoke(Distance)); +``` + +`Invoke(callback)` and `InvokeWithoutArgs(callback)` take ownership of +`callback`, which must be permanent. The type of `callback` must be a base +callback type instead of a derived one, e.g. + +```cpp + BlockingClosure* done = new BlockingClosure; + ... Invoke(done) ...; // This won't compile! + + Closure* done2 = new BlockingClosure; + ... Invoke(done2) ...; // This works. +``` + +In `InvokeArgument(...)`, if an argument needs to be passed by reference, +wrap it inside `std::ref()`. For example, + +```cpp +using ::testing::InvokeArgument; +... +InvokeArgument<2>(5, string("Hi"), std::ref(foo)) +``` + +calls the mock function's #2 argument, passing to it `5` and `string("Hi")` by +value, and `foo` by reference. + +## Default Action + +| Action | Description | +| :------------ | :----------------------------------------------------- | +| `DoDefault()` | Do the default action (specified by `ON_CALL()` or the built-in one). | + +{: .callout .note} +**Note:** due to technical reasons, `DoDefault()` cannot be used inside a +composite action - trying to do so will result in a run-time error. + +## Composite Actions + +| Action | Description | +| :----------------------------- | :------------------------------------------ | +| `DoAll(a1, a2, ..., an)` | Do all actions `a1` to `an` and return the result of `an` in each invocation. The first `n - 1` sub-actions must return void and will receive a readonly view of the arguments. | +| `IgnoreResult(a)` | Perform action `a` and ignore its result. `a` must not return void. | +| `WithArg(a)` | Pass the `N`-th (0-based) argument of the mock function to action `a` and perform it. | +| `WithArgs(a)` | Pass the selected (0-based) arguments of the mock function to action `a` and perform it. | +| `WithoutArgs(a)` | Perform action `a` without any arguments. | + +## Defining Actions + +| Macro | Description | +| :--------------------------------- | :-------------------------------------- | +| `ACTION(Sum) { return arg0 + arg1; }` | Defines an action `Sum()` to return the sum of the mock function's argument #0 and #1. | +| `ACTION_P(Plus, n) { return arg0 + n; }` | Defines an action `Plus(n)` to return the sum of the mock function's argument #0 and `n`. | +| `ACTION_Pk(Foo, p1, ..., pk) { statements; }` | Defines a parameterized action `Foo(p1, ..., pk)` to execute the given `statements`. | + +The `ACTION*` macros cannot be used inside a function or class. 
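+
+As a quick illustration of the table above, the following hedged sketch (the
+`MockCalculator` class and its methods are invented for this example) defines
+`Sum()` and `Plus(n)` at namespace scope and then uses them as actions:
+
+```cpp
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::_;
+
+// ACTION* macros must live outside of any function or class.
+ACTION(Sum) { return arg0 + arg1; }
+ACTION_P(Plus, n) { return arg0 + n; }
+
+class MockCalculator {
+ public:
+  MOCK_METHOD(int, Add, (int, int));
+  MOCK_METHOD(int, Bump, (int));
+};
+
+TEST(ActionSketch, CustomActions) {
+  MockCalculator calc;
+  EXPECT_CALL(calc, Add(_, _)).WillOnce(Sum());
+  EXPECT_CALL(calc, Bump(_)).WillRepeatedly(Plus(5));
+
+  EXPECT_EQ(calc.Add(2, 3), 5);    // Sum() returns arg0 + arg1.
+  EXPECT_EQ(calc.Bump(10), 15);    // Plus(5) returns arg0 + 5.
+}
+```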
diff --git a/3rdparty/googletest-1.13.0/docs/reference/assertions.md b/3rdparty/googletest-1.13.0/docs/reference/assertions.md new file mode 100644 index 0000000000000000000000000000000000000000..7bf03a3dde17857dfe7b508f14daa60d73bdac19 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/reference/assertions.md @@ -0,0 +1,633 @@ +# Assertions Reference + +This page lists the assertion macros provided by GoogleTest for verifying code +behavior. To use them, include the header `gtest/gtest.h`. + +The majority of the macros listed below come as a pair with an `EXPECT_` variant +and an `ASSERT_` variant. Upon failure, `EXPECT_` macros generate nonfatal +failures and allow the current function to continue running, while `ASSERT_` +macros generate fatal failures and abort the current function. + +All assertion macros support streaming a custom failure message into them with +the `<<` operator, for example: + +```cpp +EXPECT_TRUE(my_condition) << "My condition is not true"; +``` + +Anything that can be streamed to an `ostream` can be streamed to an assertion +macro—in particular, C strings and string objects. If a wide string (`wchar_t*`, +`TCHAR*` in `UNICODE` mode on Windows, or `std::wstring`) is streamed to an +assertion, it will be translated to UTF-8 when printed. + +## Explicit Success and Failure {#success-failure} + +The assertions in this section generate a success or failure directly instead of +testing a value or expression. These are useful when control flow, rather than a +Boolean expression, determines the test's success or failure, as shown by the +following example: + +```c++ +switch(expression) { + case 1: + ... some checks ... + case 2: + ... some other checks ... + default: + FAIL() << "We shouldn't get here."; +} +``` + +### SUCCEED {#SUCCEED} + +`SUCCEED()` + +Generates a success. This *does not* make the overall test succeed. A test is +considered successful only if none of its assertions fail during its execution. + +The `SUCCEED` assertion is purely documentary and currently doesn't generate any +user-visible output. However, we may add `SUCCEED` messages to GoogleTest output +in the future. + +### FAIL {#FAIL} + +`FAIL()` + +Generates a fatal failure, which returns from the current function. + +Can only be used in functions that return `void`. See +[Assertion Placement](../advanced.md#assertion-placement) for more information. + +### ADD_FAILURE {#ADD_FAILURE} + +`ADD_FAILURE()` + +Generates a nonfatal failure, which allows the current function to continue +running. + +### ADD_FAILURE_AT {#ADD_FAILURE_AT} + +`ADD_FAILURE_AT(`*`file_path`*`,`*`line_number`*`)` + +Generates a nonfatal failure at the file and line number specified. + +## Generalized Assertion {#generalized} + +The following assertion allows [matchers](matchers.md) to be used to verify +values. + +### EXPECT_THAT {#EXPECT_THAT} + +`EXPECT_THAT(`*`value`*`,`*`matcher`*`)` \ +`ASSERT_THAT(`*`value`*`,`*`matcher`*`)` + +Verifies that *`value`* matches the [matcher](matchers.md) *`matcher`*. + +For example, the following code verifies that the string `value1` starts with +`"Hello"`, `value2` matches a regular expression, and `value3` is between 5 and +10: + +```cpp +#include "gmock/gmock.h" + +using ::testing::AllOf; +using ::testing::Gt; +using ::testing::Lt; +using ::testing::MatchesRegex; +using ::testing::StartsWith; + +... 
+EXPECT_THAT(value1, StartsWith("Hello")); +EXPECT_THAT(value2, MatchesRegex("Line \\d+")); +ASSERT_THAT(value3, AllOf(Gt(5), Lt(10))); +``` + +Matchers enable assertions of this form to read like English and generate +informative failure messages. For example, if the above assertion on `value1` +fails, the resulting message will be similar to the following: + +``` +Value of: value1 + Actual: "Hi, world!" +Expected: starts with "Hello" +``` + +GoogleTest provides a built-in library of matchers—see the +[Matchers Reference](matchers.md). It is also possible to write your own +matchers—see [Writing New Matchers Quickly](../gmock_cook_book.md#NewMatchers). +The use of matchers makes `EXPECT_THAT` a powerful, extensible assertion. + +*The idea for this assertion was borrowed from Joe Walnes' Hamcrest project, +which adds `assertThat()` to JUnit.* + +## Boolean Conditions {#boolean} + +The following assertions test Boolean conditions. + +### EXPECT_TRUE {#EXPECT_TRUE} + +`EXPECT_TRUE(`*`condition`*`)` \ +`ASSERT_TRUE(`*`condition`*`)` + +Verifies that *`condition`* is true. + +### EXPECT_FALSE {#EXPECT_FALSE} + +`EXPECT_FALSE(`*`condition`*`)` \ +`ASSERT_FALSE(`*`condition`*`)` + +Verifies that *`condition`* is false. + +## Binary Comparison {#binary-comparison} + +The following assertions compare two values. The value arguments must be +comparable by the assertion's comparison operator, otherwise a compiler error +will result. + +If an argument supports the `<<` operator, it will be called to print the +argument when the assertion fails. Otherwise, GoogleTest will attempt to print +them in the best way it can—see +[Teaching GoogleTest How to Print Your Values](../advanced.md#teaching-googletest-how-to-print-your-values). + +Arguments are always evaluated exactly once, so it's OK for the arguments to +have side effects. However, the argument evaluation order is undefined and +programs should not depend on any particular argument evaluation order. + +These assertions work with both narrow and wide string objects (`string` and +`wstring`). + +See also the [Floating-Point Comparison](#floating-point) assertions to compare +floating-point numbers and avoid problems caused by rounding. + +### EXPECT_EQ {#EXPECT_EQ} + +`EXPECT_EQ(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_EQ(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`==`*`val2`*. + +Does pointer equality on pointers. If used on two C strings, it tests if they +are in the same memory location, not if they have the same value. Use +[`EXPECT_STREQ`](#EXPECT_STREQ) to compare C strings (e.g. `const char*`) by +value. + +When comparing a pointer to `NULL`, use `EXPECT_EQ(`*`ptr`*`, nullptr)` instead +of `EXPECT_EQ(`*`ptr`*`, NULL)`. + +### EXPECT_NE {#EXPECT_NE} + +`EXPECT_NE(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_NE(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`!=`*`val2`*. + +Does pointer equality on pointers. If used on two C strings, it tests if they +are in different memory locations, not if they have different values. Use +[`EXPECT_STRNE`](#EXPECT_STRNE) to compare C strings (e.g. `const char*`) by +value. + +When comparing a pointer to `NULL`, use `EXPECT_NE(`*`ptr`*`, nullptr)` instead +of `EXPECT_NE(`*`ptr`*`, NULL)`. + +### EXPECT_LT {#EXPECT_LT} + +`EXPECT_LT(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_LT(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`<`*`val2`*. + +### EXPECT_LE {#EXPECT_LE} + +`EXPECT_LE(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_LE(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`<=`*`val2`*. 
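+
+Before the remaining comparison macros, here is a brief sketch of the pointer
+versus value distinction called out under `EXPECT_EQ` above; the variables are
+hypothetical.
+
+```cpp
+#include <string>
+
+#include "gtest/gtest.h"
+
+TEST(BinaryComparisonSketch, PointerVersusValueEquality) {
+  const std::string greeting = "hello";
+  const char* c_greeting = "hello";
+
+  // std::string compares by value, so EXPECT_EQ behaves as expected.
+  EXPECT_EQ(greeting, "hello");
+
+  // For C strings, EXPECT_EQ would compare addresses; use EXPECT_STREQ
+  // (see String Comparison below) to compare contents.
+  EXPECT_STREQ(c_greeting, "hello");
+
+  // Compare pointers against nullptr rather than NULL, as noted above.
+  const char* missing = nullptr;
+  EXPECT_EQ(missing, nullptr);
+}
+```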
+ +### EXPECT_GT {#EXPECT_GT} + +`EXPECT_GT(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_GT(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`>`*`val2`*. + +### EXPECT_GE {#EXPECT_GE} + +`EXPECT_GE(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_GE(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`>=`*`val2`*. + +## String Comparison {#c-strings} + +The following assertions compare two **C strings**. To compare two `string` +objects, use [`EXPECT_EQ`](#EXPECT_EQ) or [`EXPECT_NE`](#EXPECT_NE) instead. + +These assertions also accept wide C strings (`wchar_t*`). If a comparison of two +wide strings fails, their values will be printed as UTF-8 narrow strings. + +To compare a C string with `NULL`, use `EXPECT_EQ(`*`c_string`*`, nullptr)` or +`EXPECT_NE(`*`c_string`*`, nullptr)`. + +### EXPECT_STREQ {#EXPECT_STREQ} + +`EXPECT_STREQ(`*`str1`*`,`*`str2`*`)` \ +`ASSERT_STREQ(`*`str1`*`,`*`str2`*`)` + +Verifies that the two C strings *`str1`* and *`str2`* have the same contents. + +### EXPECT_STRNE {#EXPECT_STRNE} + +`EXPECT_STRNE(`*`str1`*`,`*`str2`*`)` \ +`ASSERT_STRNE(`*`str1`*`,`*`str2`*`)` + +Verifies that the two C strings *`str1`* and *`str2`* have different contents. + +### EXPECT_STRCASEEQ {#EXPECT_STRCASEEQ} + +`EXPECT_STRCASEEQ(`*`str1`*`,`*`str2`*`)` \ +`ASSERT_STRCASEEQ(`*`str1`*`,`*`str2`*`)` + +Verifies that the two C strings *`str1`* and *`str2`* have the same contents, +ignoring case. + +### EXPECT_STRCASENE {#EXPECT_STRCASENE} + +`EXPECT_STRCASENE(`*`str1`*`,`*`str2`*`)` \ +`ASSERT_STRCASENE(`*`str1`*`,`*`str2`*`)` + +Verifies that the two C strings *`str1`* and *`str2`* have different contents, +ignoring case. + +## Floating-Point Comparison {#floating-point} + +The following assertions compare two floating-point values. + +Due to rounding errors, it is very unlikely that two floating-point values will +match exactly, so `EXPECT_EQ` is not suitable. In general, for floating-point +comparison to make sense, the user needs to carefully choose the error bound. + +GoogleTest also provides assertions that use a default error bound based on +Units in the Last Place (ULPs). To learn more about ULPs, see the article +[Comparing Floating Point Numbers](https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/). + +### EXPECT_FLOAT_EQ {#EXPECT_FLOAT_EQ} + +`EXPECT_FLOAT_EQ(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_FLOAT_EQ(`*`val1`*`,`*`val2`*`)` + +Verifies that the two `float` values *`val1`* and *`val2`* are approximately +equal, to within 4 ULPs from each other. + +### EXPECT_DOUBLE_EQ {#EXPECT_DOUBLE_EQ} + +`EXPECT_DOUBLE_EQ(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_DOUBLE_EQ(`*`val1`*`,`*`val2`*`)` + +Verifies that the two `double` values *`val1`* and *`val2`* are approximately +equal, to within 4 ULPs from each other. + +### EXPECT_NEAR {#EXPECT_NEAR} + +`EXPECT_NEAR(`*`val1`*`,`*`val2`*`,`*`abs_error`*`)` \ +`ASSERT_NEAR(`*`val1`*`,`*`val2`*`,`*`abs_error`*`)` + +Verifies that the difference between *`val1`* and *`val2`* does not exceed the +absolute error bound *`abs_error`*. + +## Exception Assertions {#exceptions} + +The following assertions verify that a piece of code throws, or does not throw, +an exception. Usage requires exceptions to be enabled in the build environment. 
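+
+As a minimal sketch, the three macros below can be exercised against
+`std::vector::at()`, which throws `std::out_of_range` on a bad index (the test
+and variable names are made up):
+
+```cpp
+#include <stdexcept>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+TEST(ExceptionAssertionSketch, VectorBoundsChecking) {
+  std::vector<int> values = {1, 2, 3};
+
+  // Throws a specific exception type.
+  EXPECT_THROW(values.at(10), std::out_of_range);
+
+  // Throws an exception of some (unspecified) type.
+  EXPECT_ANY_THROW(values.at(10));
+
+  // Does not throw at all.
+  EXPECT_NO_THROW(values.at(1));
+}
+```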
+ +Note that the piece of code under test can be a compound statement, for example: + +```cpp +EXPECT_NO_THROW({ + int n = 5; + DoSomething(&n); +}); +``` + +### EXPECT_THROW {#EXPECT_THROW} + +`EXPECT_THROW(`*`statement`*`,`*`exception_type`*`)` \ +`ASSERT_THROW(`*`statement`*`,`*`exception_type`*`)` + +Verifies that *`statement`* throws an exception of type *`exception_type`*. + +### EXPECT_ANY_THROW {#EXPECT_ANY_THROW} + +`EXPECT_ANY_THROW(`*`statement`*`)` \ +`ASSERT_ANY_THROW(`*`statement`*`)` + +Verifies that *`statement`* throws an exception of any type. + +### EXPECT_NO_THROW {#EXPECT_NO_THROW} + +`EXPECT_NO_THROW(`*`statement`*`)` \ +`ASSERT_NO_THROW(`*`statement`*`)` + +Verifies that *`statement`* does not throw any exception. + +## Predicate Assertions {#predicates} + +The following assertions enable more complex predicates to be verified while +printing a more clear failure message than if `EXPECT_TRUE` were used alone. + +### EXPECT_PRED* {#EXPECT_PRED} + +`EXPECT_PRED1(`*`pred`*`,`*`val1`*`)` \ +`EXPECT_PRED2(`*`pred`*`,`*`val1`*`,`*`val2`*`)` \ +`EXPECT_PRED3(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \ +`EXPECT_PRED4(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)` \ +`EXPECT_PRED5(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)` + +`ASSERT_PRED1(`*`pred`*`,`*`val1`*`)` \ +`ASSERT_PRED2(`*`pred`*`,`*`val1`*`,`*`val2`*`)` \ +`ASSERT_PRED3(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \ +`ASSERT_PRED4(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)` \ +`ASSERT_PRED5(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)` + +Verifies that the predicate *`pred`* returns `true` when passed the given values +as arguments. + +The parameter *`pred`* is a function or functor that accepts as many arguments +as the corresponding macro accepts values. If *`pred`* returns `true` for the +given arguments, the assertion succeeds, otherwise the assertion fails. + +When the assertion fails, it prints the value of each argument. Arguments are +always evaluated exactly once. + +As an example, see the following code: + +```cpp +// Returns true if m and n have no common divisors except 1. +bool MutuallyPrime(int m, int n) { ... } +... +const int a = 3; +const int b = 4; +const int c = 10; +... +EXPECT_PRED2(MutuallyPrime, a, b); // Succeeds +EXPECT_PRED2(MutuallyPrime, b, c); // Fails +``` + +In the above example, the first assertion succeeds, and the second fails with +the following message: + +``` +MutuallyPrime(b, c) is false, where +b is 4 +c is 10 +``` + +Note that if the given predicate is an overloaded function or a function +template, the assertion macro might not be able to determine which version to +use, and it might be necessary to explicitly specify the type of the function. +For example, for a Boolean function `IsPositive()` overloaded to take either a +single `int` or `double` argument, it would be necessary to write one of the +following: + +```cpp +EXPECT_PRED1(static_cast(IsPositive), 5); +EXPECT_PRED1(static_cast(IsPositive), 3.14); +``` + +Writing simply `EXPECT_PRED1(IsPositive, 5);` would result in a compiler error. +Similarly, to use a template function, specify the template arguments: + +```cpp +template +bool IsNegative(T x) { + return x < 0; +} +... 
+EXPECT_PRED1(IsNegative, -5); // Must specify type for IsNegative +``` + +If a template has multiple parameters, wrap the predicate in parentheses so the +macro arguments are parsed correctly: + +```cpp +ASSERT_PRED2((MyPredicate), 5, 0); +``` + +### EXPECT_PRED_FORMAT* {#EXPECT_PRED_FORMAT} + +`EXPECT_PRED_FORMAT1(`*`pred_formatter`*`,`*`val1`*`)` \ +`EXPECT_PRED_FORMAT2(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`)` \ +`EXPECT_PRED_FORMAT3(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \ +`EXPECT_PRED_FORMAT4(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)` +\ +`EXPECT_PRED_FORMAT5(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)` + +`ASSERT_PRED_FORMAT1(`*`pred_formatter`*`,`*`val1`*`)` \ +`ASSERT_PRED_FORMAT2(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`)` \ +`ASSERT_PRED_FORMAT3(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \ +`ASSERT_PRED_FORMAT4(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)` +\ +`ASSERT_PRED_FORMAT5(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)` + +Verifies that the predicate *`pred_formatter`* succeeds when passed the given +values as arguments. + +The parameter *`pred_formatter`* is a *predicate-formatter*, which is a function +or functor with the signature: + +```cpp +testing::AssertionResult PredicateFormatter(const char* expr1, + const char* expr2, + ... + const char* exprn, + T1 val1, + T2 val2, + ... + Tn valn); +``` + +where *`val1`*, *`val2`*, ..., *`valn`* are the values of the predicate +arguments, and *`expr1`*, *`expr2`*, ..., *`exprn`* are the corresponding +expressions as they appear in the source code. The types `T1`, `T2`, ..., `Tn` +can be either value types or reference types; if an argument has type `T`, it +can be declared as either `T` or `const T&`, whichever is appropriate. For more +about the return type `testing::AssertionResult`, see +[Using a Function That Returns an AssertionResult](../advanced.md#using-a-function-that-returns-an-assertionresult). + +As an example, see the following code: + +```cpp +// Returns the smallest prime common divisor of m and n, +// or 1 when m and n are mutually prime. +int SmallestPrimeCommonDivisor(int m, int n) { ... } + +// Returns true if m and n have no common divisors except 1. +bool MutuallyPrime(int m, int n) { ... } + +// A predicate-formatter for asserting that two integers are mutually prime. +testing::AssertionResult AssertMutuallyPrime(const char* m_expr, + const char* n_expr, + int m, + int n) { + if (MutuallyPrime(m, n)) return testing::AssertionSuccess(); + + return testing::AssertionFailure() << m_expr << " and " << n_expr + << " (" << m << " and " << n << ") are not mutually prime, " + << "as they have a common divisor " << SmallestPrimeCommonDivisor(m, n); +} + +... +const int a = 3; +const int b = 4; +const int c = 10; +... +EXPECT_PRED_FORMAT2(AssertMutuallyPrime, a, b); // Succeeds +EXPECT_PRED_FORMAT2(AssertMutuallyPrime, b, c); // Fails +``` + +In the above example, the final assertion fails and the predicate-formatter +produces the following failure message: + +``` +b and c (4 and 10) are not mutually prime, as they have a common divisor 2 +``` + +## Windows HRESULT Assertions {#HRESULT} + +The following assertions test for `HRESULT` success or failure. 
For example:
+
+```cpp
+CComPtr<IShellDispatch2> shell;
+ASSERT_HRESULT_SUCCEEDED(shell.CoCreateInstance(L"Shell.Application"));
+CComVariant empty;
+ASSERT_HRESULT_SUCCEEDED(shell->ShellExecute(CComBSTR(url), empty, empty, empty, empty));
+```
+
+The generated output contains the human-readable error message associated with
+the returned `HRESULT` code.
+
+### EXPECT_HRESULT_SUCCEEDED {#EXPECT_HRESULT_SUCCEEDED}
+
+`EXPECT_HRESULT_SUCCEEDED(`*`expression`*`)` \
+`ASSERT_HRESULT_SUCCEEDED(`*`expression`*`)`
+
+Verifies that *`expression`* is a success `HRESULT`.
+
+### EXPECT_HRESULT_FAILED {#EXPECT_HRESULT_FAILED}
+
+`EXPECT_HRESULT_FAILED(`*`expression`*`)` \
+`ASSERT_HRESULT_FAILED(`*`expression`*`)`
+
+Verifies that *`expression`* is a failure `HRESULT`.
+
+## Death Assertions {#death}
+
+The following assertions verify that a piece of code causes the process to
+terminate. For context, see [Death Tests](../advanced.md#death-tests).
+
+These assertions spawn a new process and execute the code under test in that
+process. How that happens depends on the platform and the variable
+`::testing::GTEST_FLAG(death_test_style)`, which is initialized from the
+command-line flag `--gtest_death_test_style`.
+
+*   On POSIX systems, `fork()` (or `clone()` on Linux) is used to spawn the
+    child, after which:
+    *   If the variable's value is `"fast"`, the death test statement is
+        immediately executed.
+    *   If the variable's value is `"threadsafe"`, the child process re-executes
+        the unit test binary just as it was originally invoked, but with some
+        extra flags to cause just the single death test under consideration to
+        be run.
+*   On Windows, the child is spawned using the `CreateProcess()` API, and
+    re-executes the binary to cause just the single death test under
+    consideration to be run - much like the `"threadsafe"` mode on POSIX.
+
+Other values for the variable are illegal and will cause the death test to
+fail. Currently, the flag's default value is **`"fast"`**.
+
+If the death test statement runs to completion without dying, the child process
+will nonetheless terminate, and the assertion fails.
+
+Note that the piece of code under test can be a compound statement, for example:
+
+```cpp
+EXPECT_DEATH({
+  int n = 5;
+  DoSomething(&n);
+}, "Error on line .* of DoSomething()");
+```
+
+### EXPECT_DEATH {#EXPECT_DEATH}
+
+`EXPECT_DEATH(`*`statement`*`,`*`matcher`*`)` \
+`ASSERT_DEATH(`*`statement`*`,`*`matcher`*`)`
+
+Verifies that *`statement`* causes the process to terminate with a nonzero exit
+status and produces `stderr` output that matches *`matcher`*.
+
+The parameter *`matcher`* is either a [matcher](matchers.md) for a `const
+std::string&`, or a regular expression (see
+[Regular Expression Syntax](../advanced.md#regular-expression-syntax))—a bare
+string *`s`* (with no matcher) is treated as
+[`ContainsRegex(s)`](matchers.md#string-matchers), **not**
+[`Eq(s)`](matchers.md#generic-comparison).
+
+For example, the following code verifies that calling `DoSomething(42)` causes
+the process to die with an error message that contains the text `My error`:
+
+```cpp
+EXPECT_DEATH(DoSomething(42), "My error");
+```
+
+### EXPECT_DEATH_IF_SUPPORTED {#EXPECT_DEATH_IF_SUPPORTED}
+
+`EXPECT_DEATH_IF_SUPPORTED(`*`statement`*`,`*`matcher`*`)` \
+`ASSERT_DEATH_IF_SUPPORTED(`*`statement`*`,`*`matcher`*`)`
+
+If death tests are supported, behaves the same as
+[`EXPECT_DEATH`](#EXPECT_DEATH). Otherwise, verifies nothing.
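+
+A minimal end-to-end sketch of the two assertions above is shown below; the
+`Divide` function, its error message, and the test names are hypothetical.
+Naming the suite with a `DeathTest` suffix follows the convention described in
+the death tests guide.
+
+```cpp
+#include <cstdio>
+#include <cstdlib>
+
+#include "gtest/gtest.h"
+
+// A hypothetical function under test: it logs to stderr and aborts on bad
+// input, so the process exits with a nonzero status.
+int Divide(int numerator, int denominator) {
+  if (denominator == 0) {
+    std::fprintf(stderr, "Divide: denominator is zero\n");
+    std::abort();
+  }
+  return numerator / denominator;
+}
+
+TEST(DivideDeathTest, AbortsOnZeroDenominator) {
+  // The bare string is treated as ContainsRegex, as explained above.
+  EXPECT_DEATH(Divide(1, 0), "denominator is zero");
+
+  // Use the *_IF_SUPPORTED variant if the test may also run on platforms
+  // without death test support.
+  EXPECT_DEATH_IF_SUPPORTED(Divide(1, 0), "denominator is zero");
+}
+```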
+ +### EXPECT_DEBUG_DEATH {#EXPECT_DEBUG_DEATH} + +`EXPECT_DEBUG_DEATH(`*`statement`*`,`*`matcher`*`)` \ +`ASSERT_DEBUG_DEATH(`*`statement`*`,`*`matcher`*`)` + +In debug mode, behaves the same as [`EXPECT_DEATH`](#EXPECT_DEATH). When not in +debug mode (i.e. `NDEBUG` is defined), just executes *`statement`*. + +### EXPECT_EXIT {#EXPECT_EXIT} + +`EXPECT_EXIT(`*`statement`*`,`*`predicate`*`,`*`matcher`*`)` \ +`ASSERT_EXIT(`*`statement`*`,`*`predicate`*`,`*`matcher`*`)` + +Verifies that *`statement`* causes the process to terminate with an exit status +that satisfies *`predicate`*, and produces `stderr` output that matches +*`matcher`*. + +The parameter *`predicate`* is a function or functor that accepts an `int` exit +status and returns a `bool`. GoogleTest provides two predicates to handle common +cases: + +```cpp +// Returns true if the program exited normally with the given exit status code. +::testing::ExitedWithCode(exit_code); + +// Returns true if the program was killed by the given signal. +// Not available on Windows. +::testing::KilledBySignal(signal_number); +``` + +The parameter *`matcher`* is either a [matcher](matchers.md) for a `const +std::string&`, or a regular expression (see +[Regular Expression Syntax](../advanced.md#regular-expression-syntax))—a bare +string *`s`* (with no matcher) is treated as +[`ContainsRegex(s)`](matchers.md#string-matchers), **not** +[`Eq(s)`](matchers.md#generic-comparison). + +For example, the following code verifies that calling `NormalExit()` causes the +process to print a message containing the text `Success` to `stderr` and exit +with exit status code 0: + +```cpp +EXPECT_EXIT(NormalExit(), testing::ExitedWithCode(0), "Success"); +``` diff --git a/3rdparty/googletest-1.13.0/docs/reference/matchers.md b/3rdparty/googletest-1.13.0/docs/reference/matchers.md new file mode 100644 index 0000000000000000000000000000000000000000..9fb159275131504ec920303268297a373502ab49 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/reference/matchers.md @@ -0,0 +1,290 @@ +# Matchers Reference + +A **matcher** matches a *single* argument. You can use it inside `ON_CALL()` or +`EXPECT_CALL()`, or use it to validate a value directly using two macros: + +| Macro | Description | +| :----------------------------------- | :------------------------------------ | +| `EXPECT_THAT(actual_value, matcher)` | Asserts that `actual_value` matches `matcher`. | +| `ASSERT_THAT(actual_value, matcher)` | The same as `EXPECT_THAT(actual_value, matcher)`, except that it generates a **fatal** failure. | + +{: .callout .warning} +**WARNING:** Equality matching via `EXPECT_THAT(actual_value, expected_value)` +is supported, however note that implicit conversions can cause surprising +results. For example, `EXPECT_THAT(some_bool, "some string")` will compile and +may pass unintentionally. + +**BEST PRACTICE:** Prefer to make the comparison explicit via +`EXPECT_THAT(actual_value, Eq(expected_value))` or `EXPECT_EQ(actual_value, +expected_value)`. + +Built-in matchers (where `argument` is the function argument, e.g. +`actual_value` in the example above, or when used in the context of +`EXPECT_CALL(mock_object, method(matchers))`, the arguments of `method`) are +divided into several categories. All matchers are defined in the `::testing` +namespace unless otherwise noted. + +## Wildcard + +Matcher | Description +:-------------------------- | :----------------------------------------------- +`_` | `argument` can be any value of the correct type. 
+`A()` or `An()` | `argument` can be any value of type `type`. + +## Generic Comparison + +| Matcher | Description | +| :--------------------- | :-------------------------------------------------- | +| `Eq(value)` or `value` | `argument == value` | +| `Ge(value)` | `argument >= value` | +| `Gt(value)` | `argument > value` | +| `Le(value)` | `argument <= value` | +| `Lt(value)` | `argument < value` | +| `Ne(value)` | `argument != value` | +| `IsFalse()` | `argument` evaluates to `false` in a Boolean context. | +| `IsTrue()` | `argument` evaluates to `true` in a Boolean context. | +| `IsNull()` | `argument` is a `NULL` pointer (raw or smart). | +| `NotNull()` | `argument` is a non-null pointer (raw or smart). | +| `Optional(m)` | `argument` is `optional<>` that contains a value matching `m`. (For testing whether an `optional<>` is set, check for equality with `nullopt`. You may need to use `Eq(nullopt)` if the inner type doesn't have `==`.)| +| `VariantWith(m)` | `argument` is `variant<>` that holds the alternative of type T with a value matching `m`. | +| `Ref(variable)` | `argument` is a reference to `variable`. | +| `TypedEq(value)` | `argument` has type `type` and is equal to `value`. You may need to use this instead of `Eq(value)` when the mock function is overloaded. | + +Except `Ref()`, these matchers make a *copy* of `value` in case it's modified or +destructed later. If the compiler complains that `value` doesn't have a public +copy constructor, try wrap it in `std::ref()`, e.g. +`Eq(std::ref(non_copyable_value))`. If you do that, make sure +`non_copyable_value` is not changed afterwards, or the meaning of your matcher +will be changed. + +`IsTrue` and `IsFalse` are useful when you need to use a matcher, or for types +that can be explicitly converted to Boolean, but are not implicitly converted to +Boolean. In other cases, you can use the basic +[`EXPECT_TRUE` and `EXPECT_FALSE`](assertions.md#boolean) assertions. + +## Floating-Point Matchers {#FpMatchers} + +| Matcher | Description | +| :------------------------------- | :--------------------------------- | +| `DoubleEq(a_double)` | `argument` is a `double` value approximately equal to `a_double`, treating two NaNs as unequal. | +| `FloatEq(a_float)` | `argument` is a `float` value approximately equal to `a_float`, treating two NaNs as unequal. | +| `NanSensitiveDoubleEq(a_double)` | `argument` is a `double` value approximately equal to `a_double`, treating two NaNs as equal. | +| `NanSensitiveFloatEq(a_float)` | `argument` is a `float` value approximately equal to `a_float`, treating two NaNs as equal. | +| `IsNan()` | `argument` is any floating-point type with a NaN value. | + +The above matchers use ULP-based comparison (the same as used in googletest). +They automatically pick a reasonable error bound based on the absolute value of +the expected value. `DoubleEq()` and `FloatEq()` conform to the IEEE standard, +which requires comparing two NaNs for equality to return false. The +`NanSensitive*` version instead treats two NaNs as equal, which is often what a +user wants. + +| Matcher | Description | +| :------------------------------------------------ | :----------------------- | +| `DoubleNear(a_double, max_abs_error)` | `argument` is a `double` value close to `a_double` (absolute error <= `max_abs_error`), treating two NaNs as unequal. | +| `FloatNear(a_float, max_abs_error)` | `argument` is a `float` value close to `a_float` (absolute error <= `max_abs_error`), treating two NaNs as unequal. 
| +| `NanSensitiveDoubleNear(a_double, max_abs_error)` | `argument` is a `double` value close to `a_double` (absolute error <= `max_abs_error`), treating two NaNs as equal. | +| `NanSensitiveFloatNear(a_float, max_abs_error)` | `argument` is a `float` value close to `a_float` (absolute error <= `max_abs_error`), treating two NaNs as equal. | + +## String Matchers + +The `argument` can be either a C string or a C++ string object: + +| Matcher | Description | +| :---------------------- | :------------------------------------------------- | +| `ContainsRegex(string)` | `argument` matches the given regular expression. | +| `EndsWith(suffix)` | `argument` ends with string `suffix`. | +| `HasSubstr(string)` | `argument` contains `string` as a sub-string. | +| `IsEmpty()` | `argument` is an empty string. | +| `MatchesRegex(string)` | `argument` matches the given regular expression with the match starting at the first character and ending at the last character. | +| `StartsWith(prefix)` | `argument` starts with string `prefix`. | +| `StrCaseEq(string)` | `argument` is equal to `string`, ignoring case. | +| `StrCaseNe(string)` | `argument` is not equal to `string`, ignoring case. | +| `StrEq(string)` | `argument` is equal to `string`. | +| `StrNe(string)` | `argument` is not equal to `string`. | +| `WhenBase64Unescaped(m)` | `argument` is a base-64 escaped string whose unescaped string matches `m`. | + +`ContainsRegex()` and `MatchesRegex()` take ownership of the `RE` object. They +use the regular expression syntax defined +[here](../advanced.md#regular-expression-syntax). All of these matchers, except +`ContainsRegex()` and `MatchesRegex()` work for wide strings as well. + +## Container Matchers + +Most STL-style containers support `==`, so you can use `Eq(expected_container)` +or simply `expected_container` to match a container exactly. If you want to +write the elements in-line, match them more flexibly, or get more informative +messages, you can use: + +| Matcher | Description | +| :---------------------------------------- | :------------------------------- | +| `BeginEndDistanceIs(m)` | `argument` is a container whose `begin()` and `end()` iterators are separated by a number of increments matching `m`. E.g. `BeginEndDistanceIs(2)` or `BeginEndDistanceIs(Lt(2))`. For containers that define a `size()` method, `SizeIs(m)` may be more efficient. | +| `ContainerEq(container)` | The same as `Eq(container)` except that the failure message also includes which elements are in one container but not the other. | +| `Contains(e)` | `argument` contains an element that matches `e`, which can be either a value or a matcher. | +| `Contains(e).Times(n)` | `argument` contains elements that match `e`, which can be either a value or a matcher, and the number of matches is `n`, which can be either a value or a matcher. Unlike the plain `Contains` and `Each` this allows to check for arbitrary occurrences including testing for absence with `Contains(e).Times(0)`. | +| `Each(e)` | `argument` is a container where *every* element matches `e`, which can be either a value or a matcher. | +| `ElementsAre(e0, e1, ..., en)` | `argument` has `n + 1` elements, where the *i*-th element matches `ei`, which can be a value or a matcher. 
| +| `ElementsAreArray({e0, e1, ..., en})`, `ElementsAreArray(a_container)`, `ElementsAreArray(begin, end)`, `ElementsAreArray(array)`, or `ElementsAreArray(array, count)` | The same as `ElementsAre()` except that the expected element values/matchers come from an initializer list, STL-style container, iterator range, or C-style array. | +| `IsEmpty()` | `argument` is an empty container (`container.empty()`). | +| `IsSubsetOf({e0, e1, ..., en})`, `IsSubsetOf(a_container)`, `IsSubsetOf(begin, end)`, `IsSubsetOf(array)`, or `IsSubsetOf(array, count)` | `argument` matches `UnorderedElementsAre(x0, x1, ..., xk)` for some subset `{x0, x1, ..., xk}` of the expected matchers. | +| `IsSupersetOf({e0, e1, ..., en})`, `IsSupersetOf(a_container)`, `IsSupersetOf(begin, end)`, `IsSupersetOf(array)`, or `IsSupersetOf(array, count)` | Some subset of `argument` matches `UnorderedElementsAre(`expected matchers`)`. | +| `Pointwise(m, container)`, `Pointwise(m, {e0, e1, ..., en})` | `argument` contains the same number of elements as in `container`, and for all i, (the i-th element in `argument`, the i-th element in `container`) match `m`, which is a matcher on 2-tuples. E.g. `Pointwise(Le(), upper_bounds)` verifies that each element in `argument` doesn't exceed the corresponding element in `upper_bounds`. See more detail below. | +| `SizeIs(m)` | `argument` is a container whose size matches `m`. E.g. `SizeIs(2)` or `SizeIs(Lt(2))`. | +| `UnorderedElementsAre(e0, e1, ..., en)` | `argument` has `n + 1` elements, and under *some* permutation of the elements, each element matches an `ei` (for a different `i`), which can be a value or a matcher. | +| `UnorderedElementsAreArray({e0, e1, ..., en})`, `UnorderedElementsAreArray(a_container)`, `UnorderedElementsAreArray(begin, end)`, `UnorderedElementsAreArray(array)`, or `UnorderedElementsAreArray(array, count)` | The same as `UnorderedElementsAre()` except that the expected element values/matchers come from an initializer list, STL-style container, iterator range, or C-style array. | +| `UnorderedPointwise(m, container)`, `UnorderedPointwise(m, {e0, e1, ..., en})` | Like `Pointwise(m, container)`, but ignores the order of elements. | +| `WhenSorted(m)` | When `argument` is sorted using the `<` operator, it matches container matcher `m`. E.g. `WhenSorted(ElementsAre(1, 2, 3))` verifies that `argument` contains elements 1, 2, and 3, ignoring order. | +| `WhenSortedBy(comparator, m)` | The same as `WhenSorted(m)`, except that the given comparator instead of `<` is used to sort `argument`. E.g. `WhenSortedBy(std::greater(), ElementsAre(3, 2, 1))`. | + +**Notes:** + +* These matchers can also match: + 1. a native array passed by reference (e.g. in `Foo(const int (&a)[5])`), + and + 2. an array passed as a pointer and a count (e.g. in `Bar(const T* buffer, + int len)` -- see [Multi-argument Matchers](#MultiArgMatchers)). +* The array being matched may be multi-dimensional (i.e. its elements can be + arrays). +* `m` in `Pointwise(m, ...)` and `UnorderedPointwise(m, ...)` should be a + matcher for `::std::tuple` where `T` and `U` are the element type of + the actual container and the expected container, respectively. For example, + to compare two `Foo` containers where `Foo` doesn't support `operator==`, + one might write: + + ```cpp + MATCHER(FooEq, "") { + return std::get<0>(arg).Equals(std::get<1>(arg)); + } + ... 
+ EXPECT_THAT(actual_foos, Pointwise(FooEq(), expected_foos)); + ``` + +## Member Matchers + +| Matcher | Description | +| :------------------------------ | :----------------------------------------- | +| `Field(&class::field, m)` | `argument.field` (or `argument->field` when `argument` is a plain pointer) matches matcher `m`, where `argument` is an object of type _class_. | +| `Field(field_name, &class::field, m)` | The same as the two-parameter version, but provides a better error message. | +| `Key(e)` | `argument.first` matches `e`, which can be either a value or a matcher. E.g. `Contains(Key(Le(5)))` can verify that a `map` contains a key `<= 5`. | +| `Pair(m1, m2)` | `argument` is an `std::pair` whose `first` field matches `m1` and `second` field matches `m2`. | +| `FieldsAre(m...)` | `argument` is a compatible object where each field matches piecewise with the matchers `m...`. A compatible object is any that supports the `std::tuple_size`+`get(obj)` protocol. In C++17 and up this also supports types compatible with structured bindings, like aggregates. | +| `Property(&class::property, m)` | `argument.property()` (or `argument->property()` when `argument` is a plain pointer) matches matcher `m`, where `argument` is an object of type _class_. The method `property()` must take no argument and be declared as `const`. | +| `Property(property_name, &class::property, m)` | The same as the two-parameter version, but provides a better error message. + +**Notes:** + +* You can use `FieldsAre()` to match any type that supports structured + bindings, such as `std::tuple`, `std::pair`, `std::array`, and aggregate + types. For example: + + ```cpp + std::tuple my_tuple{7, "hello world"}; + EXPECT_THAT(my_tuple, FieldsAre(Ge(0), HasSubstr("hello"))); + + struct MyStruct { + int value = 42; + std::string greeting = "aloha"; + }; + MyStruct s; + EXPECT_THAT(s, FieldsAre(42, "aloha")); + ``` + +* Don't use `Property()` against member functions that you do not own, because + taking addresses of functions is fragile and generally not part of the + contract of the function. + +## Matching the Result of a Function, Functor, or Callback + +| Matcher | Description | +| :--------------- | :------------------------------------------------ | +| `ResultOf(f, m)` | `f(argument)` matches matcher `m`, where `f` is a function or functor. | +| `ResultOf(result_description, f, m)` | The same as the two-parameter version, but provides a better error message. + +## Pointer Matchers + +| Matcher | Description | +| :------------------------ | :---------------------------------------------- | +| `Address(m)` | the result of `std::addressof(argument)` matches `m`. | +| `Pointee(m)` | `argument` (either a smart pointer or a raw pointer) points to a value that matches matcher `m`. | +| `Pointer(m)` | `argument` (either a smart pointer or a raw pointer) contains a pointer that matches `m`. `m` will match against the raw pointer regardless of the type of `argument`. | +| `WhenDynamicCastTo(m)` | when `argument` is passed through `dynamic_cast()`, it matches matcher `m`. | + +## Multi-argument Matchers {#MultiArgMatchers} + +Technically, all matchers match a *single* value. A "multi-argument" matcher is +just one that matches a *tuple*. 
The following matchers can be used to match a +tuple `(x, y)`: + +Matcher | Description +:------ | :---------- +`Eq()` | `x == y` +`Ge()` | `x >= y` +`Gt()` | `x > y` +`Le()` | `x <= y` +`Lt()` | `x < y` +`Ne()` | `x != y` + +You can use the following selectors to pick a subset of the arguments (or +reorder them) to participate in the matching: + +| Matcher | Description | +| :------------------------- | :---------------------------------------------- | +| `AllArgs(m)` | Equivalent to `m`. Useful as syntactic sugar in `.With(AllArgs(m))`. | +| `Args(m)` | The tuple of the `k` selected (using 0-based indices) arguments matches `m`, e.g. `Args<1, 2>(Eq())`. | + +## Composite Matchers + +You can make a matcher from one or more other matchers: + +| Matcher | Description | +| :------------------------------- | :-------------------------------------- | +| `AllOf(m1, m2, ..., mn)` | `argument` matches all of the matchers `m1` to `mn`. | +| `AllOfArray({m0, m1, ..., mn})`, `AllOfArray(a_container)`, `AllOfArray(begin, end)`, `AllOfArray(array)`, or `AllOfArray(array, count)` | The same as `AllOf()` except that the matchers come from an initializer list, STL-style container, iterator range, or C-style array. | +| `AnyOf(m1, m2, ..., mn)` | `argument` matches at least one of the matchers `m1` to `mn`. | +| `AnyOfArray({m0, m1, ..., mn})`, `AnyOfArray(a_container)`, `AnyOfArray(begin, end)`, `AnyOfArray(array)`, or `AnyOfArray(array, count)` | The same as `AnyOf()` except that the matchers come from an initializer list, STL-style container, iterator range, or C-style array. | +| `Not(m)` | `argument` doesn't match matcher `m`. | +| `Conditional(cond, m1, m2)` | Matches matcher `m1` if `cond` evaluates to true, else matches `m2`.| + +## Adapters for Matchers + +| Matcher | Description | +| :---------------------- | :------------------------------------ | +| `MatcherCast(m)` | casts matcher `m` to type `Matcher`. | +| `SafeMatcherCast(m)` | [safely casts](../gmock_cook_book.md#SafeMatcherCast) matcher `m` to type `Matcher`. | +| `Truly(predicate)` | `predicate(argument)` returns something considered by C++ to be true, where `predicate` is a function or functor. | + +`AddressSatisfies(callback)` and `Truly(callback)` take ownership of `callback`, +which must be a permanent callback. + +## Using Matchers as Predicates {#MatchersAsPredicatesCheat} + +| Matcher | Description | +| :---------------------------- | :------------------------------------------ | +| `Matches(m)(value)` | evaluates to `true` if `value` matches `m`. You can use `Matches(m)` alone as a unary functor. | +| `ExplainMatchResult(m, value, result_listener)` | evaluates to `true` if `value` matches `m`, explaining the result to `result_listener`. | +| `Value(value, m)` | evaluates to `true` if `value` matches `m`. | + +## Defining Matchers + +| Macro | Description | +| :----------------------------------- | :------------------------------------ | +| `MATCHER(IsEven, "") { return (arg % 2) == 0; }` | Defines a matcher `IsEven()` to match an even number. | +| `MATCHER_P(IsDivisibleBy, n, "") { *result_listener << "where the remainder is " << (arg % n); return (arg % n) == 0; }` | Defines a matcher `IsDivisibleBy(n)` to match a number divisible by `n`. | +| `MATCHER_P2(IsBetween, a, b, absl::StrCat(negation ? "isn't" : "is", " between ", PrintToString(a), " and ", PrintToString(b))) { return a <= arg && arg <= b; }` | Defines a matcher `IsBetween(a, b)` to match a value in the range [`a`, `b`]. | + +**Notes:** + +1. 
The `MATCHER*` macros cannot be used inside a function or class. +2. The matcher body must be *purely functional* (i.e. it cannot have any side + effect, and the result must not depend on anything other than the value + being matched and the matcher parameters). +3. You can use `PrintToString(x)` to convert a value `x` of any type to a + string. +4. You can use `ExplainMatchResult()` in a custom matcher to wrap another + matcher, for example: + + ```cpp + MATCHER_P(NestedPropertyMatches, matcher, "") { + return ExplainMatchResult(matcher, arg.nested().property(), result_listener); + } + ``` diff --git a/3rdparty/googletest-1.13.0/docs/reference/mocking.md b/3rdparty/googletest-1.13.0/docs/reference/mocking.md new file mode 100644 index 0000000000000000000000000000000000000000..e414ffbd0dea39b9a97989f2939943a4a87362bd --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/reference/mocking.md @@ -0,0 +1,589 @@ +# Mocking Reference + +This page lists the facilities provided by GoogleTest for creating and working +with mock objects. To use them, include the header +`gmock/gmock.h`. + +## Macros {#macros} + +GoogleTest defines the following macros for working with mocks. + +### MOCK_METHOD {#MOCK_METHOD} + +`MOCK_METHOD(`*`return_type`*`,`*`method_name`*`, (`*`args...`*`));` \ +`MOCK_METHOD(`*`return_type`*`,`*`method_name`*`, (`*`args...`*`), +(`*`specs...`*`));` + +Defines a mock method *`method_name`* with arguments `(`*`args...`*`)` and +return type *`return_type`* within a mock class. + +The parameters of `MOCK_METHOD` mirror the method declaration. The optional +fourth parameter *`specs...`* is a comma-separated list of qualifiers. The +following qualifiers are accepted: + +| Qualifier | Meaning | +| -------------------------- | -------------------------------------------- | +| `const` | Makes the mocked method a `const` method. Required if overriding a `const` method. | +| `override` | Marks the method with `override`. Recommended if overriding a `virtual` method. | +| `noexcept` | Marks the method with `noexcept`. Required if overriding a `noexcept` method. | +| `Calltype(`*`calltype`*`)` | Sets the call type for the method, for example `Calltype(STDMETHODCALLTYPE)`. Useful on Windows. | +| `ref(`*`qualifier`*`)` | Marks the method with the given reference qualifier, for example `ref(&)` or `ref(&&)`. Required if overriding a method that has a reference qualifier. | + +Note that commas in arguments prevent `MOCK_METHOD` from parsing the arguments +correctly if they are not appropriately surrounded by parentheses. See the +following example: + +```cpp +class MyMock { + public: + // The following 2 lines will not compile due to commas in the arguments: + MOCK_METHOD(std::pair, GetPair, ()); // Error! + MOCK_METHOD(bool, CheckMap, (std::map, bool)); // Error! + + // One solution - wrap arguments that contain commas in parentheses: + MOCK_METHOD((std::pair), GetPair, ()); + MOCK_METHOD(bool, CheckMap, ((std::map), bool)); + + // Another solution - use type aliases: + using BoolAndInt = std::pair; + MOCK_METHOD(BoolAndInt, GetPair, ()); + using MapIntDouble = std::map; + MOCK_METHOD(bool, CheckMap, (MapIntDouble, bool)); +}; +``` + +`MOCK_METHOD` must be used in the `public:` section of a mock class definition, +regardless of whether the method being mocked is `public`, `protected`, or +`private` in the base class. 
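+
+Putting the pieces above together, the following sketch mocks a small
+hypothetical interface (`KeyValueStore` is not part of GoogleTest), including a
+return type that contains a comma and a `const` method:
+
+```cpp
+#include <string>
+#include <utility>
+
+#include "gmock/gmock.h"
+
+// A hypothetical interface to be mocked.
+class KeyValueStore {
+ public:
+  virtual ~KeyValueStore() = default;
+  virtual bool Put(const std::string& key, int value) = 0;
+  virtual std::pair<bool, int> Get(const std::string& key) const = 0;
+};
+
+class MockKeyValueStore : public KeyValueStore {
+ public:
+  MOCK_METHOD(bool, Put, (const std::string& key, int value), (override));
+  // The return type contains a comma, so it is wrapped in parentheses;
+  // `const` is required because the overridden method is const.
+  MOCK_METHOD((std::pair<bool, int>), Get, (const std::string& key),
+              (const, override));
+};
+```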
+ +### EXPECT_CALL {#EXPECT_CALL} + +`EXPECT_CALL(`*`mock_object`*`,`*`method_name`*`(`*`matchers...`*`))` + +Creates an [expectation](../gmock_for_dummies.md#setting-expectations) that the +method *`method_name`* of the object *`mock_object`* is called with arguments +that match the given matchers *`matchers...`*. `EXPECT_CALL` must precede any +code that exercises the mock object. + +The parameter *`matchers...`* is a comma-separated list of +[matchers](../gmock_for_dummies.md#matchers-what-arguments-do-we-expect) that +correspond to each argument of the method *`method_name`*. The expectation will +apply only to calls of *`method_name`* whose arguments match all of the +matchers. If `(`*`matchers...`*`)` is omitted, the expectation behaves as if +each argument's matcher were a [wildcard matcher (`_`)](matchers.md#wildcard). +See the [Matchers Reference](matchers.md) for a list of all built-in matchers. + +The following chainable clauses can be used to modify the expectation, and they +must be used in the following order: + +```cpp +EXPECT_CALL(mock_object, method_name(matchers...)) + .With(multi_argument_matcher) // Can be used at most once + .Times(cardinality) // Can be used at most once + .InSequence(sequences...) // Can be used any number of times + .After(expectations...) // Can be used any number of times + .WillOnce(action) // Can be used any number of times + .WillRepeatedly(action) // Can be used at most once + .RetiresOnSaturation(); // Can be used at most once +``` + +See details for each modifier clause below. + +#### With {#EXPECT_CALL.With} + +`.With(`*`multi_argument_matcher`*`)` + +Restricts the expectation to apply only to mock function calls whose arguments +as a whole match the multi-argument matcher *`multi_argument_matcher`*. + +GoogleTest passes all of the arguments as one tuple into the matcher. The +parameter *`multi_argument_matcher`* must thus be a matcher of type +`Matcher>`, where `A1, ..., An` are the types of the +function arguments. + +For example, the following code sets the expectation that +`my_mock.SetPosition()` is called with any two arguments, the first argument +being less than the second: + +```cpp +using ::testing::_; +using ::testing::Lt; +... +EXPECT_CALL(my_mock, SetPosition(_, _)) + .With(Lt()); +``` + +GoogleTest provides some built-in matchers for 2-tuples, including the `Lt()` +matcher above. See [Multi-argument Matchers](matchers.md#MultiArgMatchers). + +The `With` clause can be used at most once on an expectation and must be the +first clause. + +#### Times {#EXPECT_CALL.Times} + +`.Times(`*`cardinality`*`)` + +Specifies how many times the mock function call is expected. + +The parameter *`cardinality`* represents the number of expected calls and can be +one of the following, all defined in the `::testing` namespace: + +| Cardinality | Meaning | +| ------------------- | --------------------------------------------------- | +| `AnyNumber()` | The function can be called any number of times. | +| `AtLeast(n)` | The function call is expected at least *n* times. | +| `AtMost(n)` | The function call is expected at most *n* times. | +| `Between(m, n)` | The function call is expected between *m* and *n* times, inclusive. | +| `Exactly(n)` or `n` | The function call is expected exactly *n* times. If *n* is 0, the call should never happen. 
| + +If the `Times` clause is omitted, GoogleTest infers the cardinality as follows: + +* If neither [`WillOnce`](#EXPECT_CALL.WillOnce) nor + [`WillRepeatedly`](#EXPECT_CALL.WillRepeatedly) are specified, the inferred + cardinality is `Times(1)`. +* If there are *n* `WillOnce` clauses and no `WillRepeatedly` clause, where + *n* >= 1, the inferred cardinality is `Times(n)`. +* If there are *n* `WillOnce` clauses and one `WillRepeatedly` clause, where + *n* >= 0, the inferred cardinality is `Times(AtLeast(n))`. + +The `Times` clause can be used at most once on an expectation. + +#### InSequence {#EXPECT_CALL.InSequence} + +`.InSequence(`*`sequences...`*`)` + +Specifies that the mock function call is expected in a certain sequence. + +The parameter *`sequences...`* is any number of [`Sequence`](#Sequence) objects. +Expected calls assigned to the same sequence are expected to occur in the order +the expectations are declared. + +For example, the following code sets the expectation that the `Reset()` method +of `my_mock` is called before both `GetSize()` and `Describe()`, and `GetSize()` +and `Describe()` can occur in any order relative to each other: + +```cpp +using ::testing::Sequence; +Sequence s1, s2; +... +EXPECT_CALL(my_mock, Reset()) + .InSequence(s1, s2); +EXPECT_CALL(my_mock, GetSize()) + .InSequence(s1); +EXPECT_CALL(my_mock, Describe()) + .InSequence(s2); +``` + +The `InSequence` clause can be used any number of times on an expectation. + +See also the [`InSequence` class](#InSequence). + +#### After {#EXPECT_CALL.After} + +`.After(`*`expectations...`*`)` + +Specifies that the mock function call is expected to occur after one or more +other calls. + +The parameter *`expectations...`* can be up to five +[`Expectation`](#Expectation) or [`ExpectationSet`](#ExpectationSet) objects. +The mock function call is expected to occur after all of the given expectations. + +For example, the following code sets the expectation that the `Describe()` +method of `my_mock` is called only after both `InitX()` and `InitY()` have been +called. + +```cpp +using ::testing::Expectation; +... +Expectation init_x = EXPECT_CALL(my_mock, InitX()); +Expectation init_y = EXPECT_CALL(my_mock, InitY()); +EXPECT_CALL(my_mock, Describe()) + .After(init_x, init_y); +``` + +The `ExpectationSet` object is helpful when the number of prerequisites for an +expectation is large or variable, for example: + +```cpp +using ::testing::ExpectationSet; +... +ExpectationSet all_inits; +// Collect all expectations of InitElement() calls +for (int i = 0; i < element_count; i++) { + all_inits += EXPECT_CALL(my_mock, InitElement(i)); +} +EXPECT_CALL(my_mock, Describe()) + .After(all_inits); // Expect Describe() call after all InitElement() calls +``` + +The `After` clause can be used any number of times on an expectation. + +#### WillOnce {#EXPECT_CALL.WillOnce} + +`.WillOnce(`*`action`*`)` + +Specifies the mock function's actual behavior when invoked, for a single +matching function call. + +The parameter *`action`* represents the +[action](../gmock_for_dummies.md#actions-what-should-it-do) that the function +call will perform. See the [Actions Reference](actions.md) for a list of +built-in actions. + +The use of `WillOnce` implicitly sets a cardinality on the expectation when +`Times` is not specified. See [`Times`](#EXPECT_CALL.Times). + +Each matching function call will perform the next action in the order declared. 
+For example, the following code specifies that `my_mock.GetNumber()` is expected +to be called exactly 3 times and will return `1`, `2`, and `3` respectively on +the first, second, and third calls: + +```cpp +using ::testing::Return; +... +EXPECT_CALL(my_mock, GetNumber()) + .WillOnce(Return(1)) + .WillOnce(Return(2)) + .WillOnce(Return(3)); +``` + +The `WillOnce` clause can be used any number of times on an expectation. Unlike +`WillRepeatedly`, the action fed to each `WillOnce` call will be called at most +once, so may be a move-only type and/or have an `&&`-qualified call operator. + +#### WillRepeatedly {#EXPECT_CALL.WillRepeatedly} + +`.WillRepeatedly(`*`action`*`)` + +Specifies the mock function's actual behavior when invoked, for all subsequent +matching function calls. Takes effect after the actions specified in the +[`WillOnce`](#EXPECT_CALL.WillOnce) clauses, if any, have been performed. + +The parameter *`action`* represents the +[action](../gmock_for_dummies.md#actions-what-should-it-do) that the function +call will perform. See the [Actions Reference](actions.md) for a list of +built-in actions. + +The use of `WillRepeatedly` implicitly sets a cardinality on the expectation +when `Times` is not specified. See [`Times`](#EXPECT_CALL.Times). + +If any `WillOnce` clauses have been specified, matching function calls will +perform those actions before the action specified by `WillRepeatedly`. See the +following example: + +```cpp +using ::testing::Return; +... +EXPECT_CALL(my_mock, GetName()) + .WillRepeatedly(Return("John Doe")); // Return "John Doe" on all calls + +EXPECT_CALL(my_mock, GetNumber()) + .WillOnce(Return(42)) // Return 42 on the first call + .WillRepeatedly(Return(7)); // Return 7 on all subsequent calls +``` + +The `WillRepeatedly` clause can be used at most once on an expectation. + +#### RetiresOnSaturation {#EXPECT_CALL.RetiresOnSaturation} + +`.RetiresOnSaturation()` + +Indicates that the expectation will no longer be active after the expected +number of matching function calls has been reached. + +The `RetiresOnSaturation` clause is only meaningful for expectations with an +upper-bounded cardinality. The expectation will *retire* (no longer match any +function calls) after it has been *saturated* (the upper bound has been +reached). See the following example: + +```cpp +using ::testing::_; +using ::testing::AnyNumber; +... +EXPECT_CALL(my_mock, SetNumber(_)) // Expectation 1 + .Times(AnyNumber()); +EXPECT_CALL(my_mock, SetNumber(7)) // Expectation 2 + .Times(2) + .RetiresOnSaturation(); +``` + +In the above example, the first two calls to `my_mock.SetNumber(7)` match +expectation 2, which then becomes inactive and no longer matches any calls. A +third call to `my_mock.SetNumber(7)` would then match expectation 1. Without +`RetiresOnSaturation()` on expectation 2, a third call to `my_mock.SetNumber(7)` +would match expectation 2 again, producing a failure since the limit of 2 calls +was exceeded. + +The `RetiresOnSaturation` clause can be used at most once on an expectation and +must be the last clause. + +### ON_CALL {#ON_CALL} + +`ON_CALL(`*`mock_object`*`,`*`method_name`*`(`*`matchers...`*`))` + +Defines what happens when the method *`method_name`* of the object +*`mock_object`* is called with arguments that match the given matchers +*`matchers...`*. Requires a modifier clause to specify the method's behavior. +*Does not* set any expectations that the method will be called. 
+ +The parameter *`matchers...`* is a comma-separated list of +[matchers](../gmock_for_dummies.md#matchers-what-arguments-do-we-expect) that +correspond to each argument of the method *`method_name`*. The `ON_CALL` +specification will apply only to calls of *`method_name`* whose arguments match +all of the matchers. If `(`*`matchers...`*`)` is omitted, the behavior is as if +each argument's matcher were a [wildcard matcher (`_`)](matchers.md#wildcard). +See the [Matchers Reference](matchers.md) for a list of all built-in matchers. + +The following chainable clauses can be used to set the method's behavior, and +they must be used in the following order: + +```cpp +ON_CALL(mock_object, method_name(matchers...)) + .With(multi_argument_matcher) // Can be used at most once + .WillByDefault(action); // Required +``` + +See details for each modifier clause below. + +#### With {#ON_CALL.With} + +`.With(`*`multi_argument_matcher`*`)` + +Restricts the specification to only mock function calls whose arguments as a +whole match the multi-argument matcher *`multi_argument_matcher`*. + +GoogleTest passes all of the arguments as one tuple into the matcher. The +parameter *`multi_argument_matcher`* must thus be a matcher of type +`Matcher>`, where `A1, ..., An` are the types of the +function arguments. + +For example, the following code sets the default behavior when +`my_mock.SetPosition()` is called with any two arguments, the first argument +being less than the second: + +```cpp +using ::testing::_; +using ::testing::Lt; +using ::testing::Return; +... +ON_CALL(my_mock, SetPosition(_, _)) + .With(Lt()) + .WillByDefault(Return(true)); +``` + +GoogleTest provides some built-in matchers for 2-tuples, including the `Lt()` +matcher above. See [Multi-argument Matchers](matchers.md#MultiArgMatchers). + +The `With` clause can be used at most once with each `ON_CALL` statement. + +#### WillByDefault {#ON_CALL.WillByDefault} + +`.WillByDefault(`*`action`*`)` + +Specifies the default behavior of a matching mock function call. + +The parameter *`action`* represents the +[action](../gmock_for_dummies.md#actions-what-should-it-do) that the function +call will perform. See the [Actions Reference](actions.md) for a list of +built-in actions. + +For example, the following code specifies that by default, a call to +`my_mock.Greet()` will return `"hello"`: + +```cpp +using ::testing::Return; +... +ON_CALL(my_mock, Greet()) + .WillByDefault(Return("hello")); +``` + +The action specified by `WillByDefault` is superseded by the actions specified +on a matching `EXPECT_CALL` statement, if any. See the +[`WillOnce`](#EXPECT_CALL.WillOnce) and +[`WillRepeatedly`](#EXPECT_CALL.WillRepeatedly) clauses of `EXPECT_CALL`. + +The `WillByDefault` clause must be used exactly once with each `ON_CALL` +statement. + +## Classes {#classes} + +GoogleTest defines the following classes for working with mocks. + +### DefaultValue {#DefaultValue} + +`::testing::DefaultValue` + +Allows a user to specify the default value for a type `T` that is both copyable +and publicly destructible (i.e. anything that can be used as a function return +type). For mock functions with a return type of `T`, this default value is +returned from function calls that do not specify an action. + +Provides the static methods `Set()`, `SetFactory()`, and `Clear()` to manage the +default value: + +```cpp +// Sets the default value to be returned. T must be copy constructible. +DefaultValue::Set(value); + +// Sets a factory. Will be invoked on demand. 
T must be move constructible. +T MakeT(); +DefaultValue::SetFactory(&MakeT); + +// Unsets the default value. +DefaultValue::Clear(); +``` + +### NiceMock {#NiceMock} + +`::testing::NiceMock` + +Represents a mock object that suppresses warnings on +[uninteresting calls](../gmock_cook_book.md#uninteresting-vs-unexpected). The +template parameter `T` is any mock class, except for another `NiceMock`, +`NaggyMock`, or `StrictMock`. + +Usage of `NiceMock` is analogous to usage of `T`. `NiceMock` is a subclass +of `T`, so it can be used wherever an object of type `T` is accepted. In +addition, `NiceMock` can be constructed with any arguments that a constructor +of `T` accepts. + +For example, the following code suppresses warnings on the mock `my_mock` of +type `MockClass` if a method other than `DoSomething()` is called: + +```cpp +using ::testing::NiceMock; +... +NiceMock my_mock("some", "args"); +EXPECT_CALL(my_mock, DoSomething()); +... code that uses my_mock ... +``` + +`NiceMock` only works for mock methods defined using the `MOCK_METHOD` macro +directly in the definition of class `T`. If a mock method is defined in a base +class of `T`, a warning might still be generated. + +`NiceMock` might not work correctly if the destructor of `T` is not virtual. + +### NaggyMock {#NaggyMock} + +`::testing::NaggyMock` + +Represents a mock object that generates warnings on +[uninteresting calls](../gmock_cook_book.md#uninteresting-vs-unexpected). The +template parameter `T` is any mock class, except for another `NiceMock`, +`NaggyMock`, or `StrictMock`. + +Usage of `NaggyMock` is analogous to usage of `T`. `NaggyMock` is a +subclass of `T`, so it can be used wherever an object of type `T` is accepted. +In addition, `NaggyMock` can be constructed with any arguments that a +constructor of `T` accepts. + +For example, the following code generates warnings on the mock `my_mock` of type +`MockClass` if a method other than `DoSomething()` is called: + +```cpp +using ::testing::NaggyMock; +... +NaggyMock my_mock("some", "args"); +EXPECT_CALL(my_mock, DoSomething()); +... code that uses my_mock ... +``` + +Mock objects of type `T` by default behave the same way as `NaggyMock`. + +### StrictMock {#StrictMock} + +`::testing::StrictMock` + +Represents a mock object that generates test failures on +[uninteresting calls](../gmock_cook_book.md#uninteresting-vs-unexpected). The +template parameter `T` is any mock class, except for another `NiceMock`, +`NaggyMock`, or `StrictMock`. + +Usage of `StrictMock` is analogous to usage of `T`. `StrictMock` is a +subclass of `T`, so it can be used wherever an object of type `T` is accepted. +In addition, `StrictMock` can be constructed with any arguments that a +constructor of `T` accepts. + +For example, the following code generates a test failure on the mock `my_mock` +of type `MockClass` if a method other than `DoSomething()` is called: + +```cpp +using ::testing::StrictMock; +... +StrictMock my_mock("some", "args"); +EXPECT_CALL(my_mock, DoSomething()); +... code that uses my_mock ... +``` + +`StrictMock` only works for mock methods defined using the `MOCK_METHOD` +macro directly in the definition of class `T`. If a mock method is defined in a +base class of `T`, a failure might not be generated. + +`StrictMock` might not work correctly if the destructor of `T` is not +virtual. + +### Sequence {#Sequence} + +`::testing::Sequence` + +Represents a chronological sequence of expectations. See the +[`InSequence`](#EXPECT_CALL.InSequence) clause of `EXPECT_CALL` for usage. 
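+
+For illustration, the following sketch puts three expectations into a partial
+order using two `Sequence` objects (the mock `my_mock` and its `Reset()`,
+`GetSize()`, and `Describe()` methods are hypothetical): `Reset()` must be
+called before both `GetSize()` and `Describe()`, but the relative order of the
+latter two is left unconstrained.
+
+```cpp
+using ::testing::A;
+using ::testing::Sequence;
+...
+Sequence s1, s2;
+
+EXPECT_CALL(my_mock, Reset())
+    .InSequence(s1, s2);  // Must be satisfied first in both sequences.
+EXPECT_CALL(my_mock, GetSize())
+    .InSequence(s1);      // Ordered only relative to Reset().
+EXPECT_CALL(my_mock, Describe(A<const char*>()))
+    .InSequence(s2);      // Ordered only relative to Reset().
+```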
+ +### InSequence {#InSequence} + +`::testing::InSequence` + +An object of this type causes all expectations encountered in its scope to be +put in an anonymous sequence. + +This allows more convenient expression of multiple expectations in a single +sequence: + +```cpp +using ::testing::InSequence; +{ + InSequence seq; + + // The following are expected to occur in the order declared. + EXPECT_CALL(...); + EXPECT_CALL(...); + ... + EXPECT_CALL(...); +} +``` + +The name of the `InSequence` object does not matter. + +### Expectation {#Expectation} + +`::testing::Expectation` + +Represents a mock function call expectation as created by +[`EXPECT_CALL`](#EXPECT_CALL): + +```cpp +using ::testing::Expectation; +Expectation my_expectation = EXPECT_CALL(...); +``` + +Useful for specifying sequences of expectations; see the +[`After`](#EXPECT_CALL.After) clause of `EXPECT_CALL`. + +### ExpectationSet {#ExpectationSet} + +`::testing::ExpectationSet` + +Represents a set of mock function call expectations. + +Use the `+=` operator to add [`Expectation`](#Expectation) objects to the set: + +```cpp +using ::testing::ExpectationSet; +ExpectationSet my_expectations; +my_expectations += EXPECT_CALL(...); +``` + +Useful for specifying sequences of expectations; see the +[`After`](#EXPECT_CALL.After) clause of `EXPECT_CALL`. diff --git a/3rdparty/googletest-1.13.0/docs/reference/testing.md b/3rdparty/googletest-1.13.0/docs/reference/testing.md new file mode 100644 index 0000000000000000000000000000000000000000..62cdcc1c6555542856533412608e2173e7ea9a0d --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/reference/testing.md @@ -0,0 +1,1431 @@ +# Testing Reference + + + +This page lists the facilities provided by GoogleTest for writing test programs. +To use them, include the header `gtest/gtest.h`. + +## Macros + +GoogleTest defines the following macros for writing tests. + +### TEST {#TEST} + +
+TEST(TestSuiteName, TestName) {
+  ... statements ...
+}
+
+ +Defines an individual test named *`TestName`* in the test suite +*`TestSuiteName`*, consisting of the given statements. + +Both arguments *`TestSuiteName`* and *`TestName`* must be valid C++ identifiers +and must not contain underscores (`_`). Tests in different test suites can have +the same individual name. + +The statements within the test body can be any code under test. +[Assertions](assertions.md) used within the test body determine the outcome of +the test. + +### TEST_F {#TEST_F} + +
+TEST_F(TestFixtureName, TestName) {
+  ... statements ...
+}
+
+ +Defines an individual test named *`TestName`* that uses the test fixture class +*`TestFixtureName`*. The test suite name is *`TestFixtureName`*. + +Both arguments *`TestFixtureName`* and *`TestName`* must be valid C++ +identifiers and must not contain underscores (`_`). *`TestFixtureName`* must be +the name of a test fixture class—see +[Test Fixtures](../primer.md#same-data-multiple-tests). + +The statements within the test body can be any code under test. +[Assertions](assertions.md) used within the test body determine the outcome of +the test. + +### TEST_P {#TEST_P} + +
+TEST_P(TestFixtureName, TestName) {
+  ... statements ...
+}
+
+ +Defines an individual value-parameterized test named *`TestName`* that uses the +test fixture class *`TestFixtureName`*. The test suite name is +*`TestFixtureName`*. + +Both arguments *`TestFixtureName`* and *`TestName`* must be valid C++ +identifiers and must not contain underscores (`_`). *`TestFixtureName`* must be +the name of a value-parameterized test fixture class—see +[Value-Parameterized Tests](../advanced.md#value-parameterized-tests). + +The statements within the test body can be any code under test. Within the test +body, the test parameter can be accessed with the `GetParam()` function (see +[`WithParamInterface`](#WithParamInterface)). For example: + +```cpp +TEST_P(MyTestSuite, DoesSomething) { + ... + EXPECT_TRUE(DoSomething(GetParam())); + ... +} +``` + +[Assertions](assertions.md) used within the test body determine the outcome of +the test. + +See also [`INSTANTIATE_TEST_SUITE_P`](#INSTANTIATE_TEST_SUITE_P). + +### INSTANTIATE_TEST_SUITE_P {#INSTANTIATE_TEST_SUITE_P} + +`INSTANTIATE_TEST_SUITE_P(`*`InstantiationName`*`,`*`TestSuiteName`*`,`*`param_generator`*`)` +\ +`INSTANTIATE_TEST_SUITE_P(`*`InstantiationName`*`,`*`TestSuiteName`*`,`*`param_generator`*`,`*`name_generator`*`)` + +Instantiates the value-parameterized test suite *`TestSuiteName`* (defined with +[`TEST_P`](#TEST_P)). + +The argument *`InstantiationName`* is a unique name for the instantiation of the +test suite, to distinguish between multiple instantiations. In test output, the +instantiation name is added as a prefix to the test suite name +*`TestSuiteName`*. + +The argument *`param_generator`* is one of the following GoogleTest-provided +functions that generate the test parameters, all defined in the `::testing` +namespace: + + + +| Parameter Generator | Behavior | +| ------------------- | ---------------------------------------------------- | +| `Range(begin, end [, step])` | Yields values `{begin, begin+step, begin+step+step, ...}`. The values do not include `end`. `step` defaults to 1. | +| `Values(v1, v2, ..., vN)` | Yields values `{v1, v2, ..., vN}`. | +| `ValuesIn(container)` or `ValuesIn(begin,end)` | Yields values from a C-style array, an STL-style container, or an iterator range `[begin, end)`. | +| `Bool()` | Yields sequence `{false, true}`. | +| `Combine(g1, g2, ..., gN)` | Yields as `std::tuple` *n*-tuples all combinations (Cartesian product) of the values generated by the given *n* generators `g1`, `g2`, ..., `gN`. | +| `ConvertGenerator(g)` | Yields values generated by generator `g`, `static_cast` to `T`. | +The optional last argument *`name_generator`* is a function or functor that +generates custom test name suffixes based on the test parameters. The function +must accept an argument of type +[`TestParamInfo`](#TestParamInfo) and return a `std::string`. +The test name suffix can only contain alphanumeric characters and underscores. +GoogleTest provides [`PrintToStringParamName`](#PrintToStringParamName), or a +custom function can be used for more control: + +```cpp +INSTANTIATE_TEST_SUITE_P( + MyInstantiation, MyTestSuite, + ::testing::Values(...), + [](const ::testing::TestParamInfo& info) { + // Can use info.param here to generate the test suffix + std::string name = ... + return name; + }); +``` + +For more information, see +[Value-Parameterized Tests](../advanced.md#value-parameterized-tests). + +See also +[`GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST`](#GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST). 
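+
+Putting the pieces together, a minimal sketch (reusing the hypothetical
+`MyTestSuite` fixture and `DoSomething()` function from the `TEST_P` example
+above) can instantiate a single suite more than once with different parameter
+generators:
+
+```cpp
+class MyTestSuite : public ::testing::TestWithParam<int> {};
+
+TEST_P(MyTestSuite, DoesSomething) {
+  EXPECT_TRUE(DoSomething(GetParam()));
+}
+
+// Two independent instantiations of the same suite. The first yields the
+// parameters 0, 1, and 2; the second yields 10 and 20 and names each test
+// after the printed parameter value.
+INSTANTIATE_TEST_SUITE_P(SmallRange, MyTestSuite, ::testing::Range(0, 3));
+INSTANTIATE_TEST_SUITE_P(PickedValues, MyTestSuite,
+                         ::testing::Values(10, 20),
+                         ::testing::PrintToStringParamName());
+```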
+ +### TYPED_TEST_SUITE {#TYPED_TEST_SUITE} + +`TYPED_TEST_SUITE(`*`TestFixtureName`*`,`*`Types`*`)` + +Defines a typed test suite based on the test fixture *`TestFixtureName`*. The +test suite name is *`TestFixtureName`*. + +The argument *`TestFixtureName`* is a fixture class template, parameterized by a +type, for example: + +```cpp +template +class MyFixture : public ::testing::Test { + public: + ... + using List = std::list; + static T shared_; + T value_; +}; +``` + +The argument *`Types`* is a [`Types`](#Types) object representing the list of +types to run the tests on, for example: + +```cpp +using MyTypes = ::testing::Types; +TYPED_TEST_SUITE(MyFixture, MyTypes); +``` + +The type alias (`using` or `typedef`) is necessary for the `TYPED_TEST_SUITE` +macro to parse correctly. + +See also [`TYPED_TEST`](#TYPED_TEST) and +[Typed Tests](../advanced.md#typed-tests) for more information. + +### TYPED_TEST {#TYPED_TEST} + +
+TYPED_TEST(TestSuiteName, TestName) {
+  ... statements ...
+}
+
+ +Defines an individual typed test named *`TestName`* in the typed test suite +*`TestSuiteName`*. The test suite must be defined with +[`TYPED_TEST_SUITE`](#TYPED_TEST_SUITE). + +Within the test body, the special name `TypeParam` refers to the type parameter, +and `TestFixture` refers to the fixture class. See the following example: + +```cpp +TYPED_TEST(MyFixture, Example) { + // Inside a test, refer to the special name TypeParam to get the type + // parameter. Since we are inside a derived class template, C++ requires + // us to visit the members of MyFixture via 'this'. + TypeParam n = this->value_; + + // To visit static members of the fixture, add the 'TestFixture::' + // prefix. + n += TestFixture::shared_; + + // To refer to typedefs in the fixture, add the 'typename TestFixture::' + // prefix. The 'typename' is required to satisfy the compiler. + typename TestFixture::List values; + + values.push_back(n); + ... +} +``` + +For more information, see [Typed Tests](../advanced.md#typed-tests). + +### TYPED_TEST_SUITE_P {#TYPED_TEST_SUITE_P} + +`TYPED_TEST_SUITE_P(`*`TestFixtureName`*`)` + +Defines a type-parameterized test suite based on the test fixture +*`TestFixtureName`*. The test suite name is *`TestFixtureName`*. + +The argument *`TestFixtureName`* is a fixture class template, parameterized by a +type. See [`TYPED_TEST_SUITE`](#TYPED_TEST_SUITE) for an example. + +See also [`TYPED_TEST_P`](#TYPED_TEST_P) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more +information. + +### TYPED_TEST_P {#TYPED_TEST_P} + +
+TYPED_TEST_P(TestSuiteName, TestName) {
+  ... statements ...
+}
+
+ +Defines an individual type-parameterized test named *`TestName`* in the +type-parameterized test suite *`TestSuiteName`*. The test suite must be defined +with [`TYPED_TEST_SUITE_P`](#TYPED_TEST_SUITE_P). + +Within the test body, the special name `TypeParam` refers to the type parameter, +and `TestFixture` refers to the fixture class. See [`TYPED_TEST`](#TYPED_TEST) +for an example. + +See also [`REGISTER_TYPED_TEST_SUITE_P`](#REGISTER_TYPED_TEST_SUITE_P) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more +information. + +### REGISTER_TYPED_TEST_SUITE_P {#REGISTER_TYPED_TEST_SUITE_P} + +`REGISTER_TYPED_TEST_SUITE_P(`*`TestSuiteName`*`,`*`TestNames...`*`)` + +Registers the type-parameterized tests *`TestNames...`* of the test suite +*`TestSuiteName`*. The test suite and tests must be defined with +[`TYPED_TEST_SUITE_P`](#TYPED_TEST_SUITE_P) and [`TYPED_TEST_P`](#TYPED_TEST_P). + +For example: + +```cpp +// Define the test suite and tests. +TYPED_TEST_SUITE_P(MyFixture); +TYPED_TEST_P(MyFixture, HasPropertyA) { ... } +TYPED_TEST_P(MyFixture, HasPropertyB) { ... } + +// Register the tests in the test suite. +REGISTER_TYPED_TEST_SUITE_P(MyFixture, HasPropertyA, HasPropertyB); +``` + +See also [`INSTANTIATE_TYPED_TEST_SUITE_P`](#INSTANTIATE_TYPED_TEST_SUITE_P) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more +information. + +### INSTANTIATE_TYPED_TEST_SUITE_P {#INSTANTIATE_TYPED_TEST_SUITE_P} + +`INSTANTIATE_TYPED_TEST_SUITE_P(`*`InstantiationName`*`,`*`TestSuiteName`*`,`*`Types`*`)` + +Instantiates the type-parameterized test suite *`TestSuiteName`*. The test suite +must be registered with +[`REGISTER_TYPED_TEST_SUITE_P`](#REGISTER_TYPED_TEST_SUITE_P). + +The argument *`InstantiationName`* is a unique name for the instantiation of the +test suite, to distinguish between multiple instantiations. In test output, the +instantiation name is added as a prefix to the test suite name +*`TestSuiteName`*. + +The argument *`Types`* is a [`Types`](#Types) object representing the list of +types to run the tests on, for example: + +```cpp +using MyTypes = ::testing::Types; +INSTANTIATE_TYPED_TEST_SUITE_P(MyInstantiation, MyFixture, MyTypes); +``` + +The type alias (`using` or `typedef`) is necessary for the +`INSTANTIATE_TYPED_TEST_SUITE_P` macro to parse correctly. + +For more information, see +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests). + +### FRIEND_TEST {#FRIEND_TEST} + +`FRIEND_TEST(`*`TestSuiteName`*`,`*`TestName`*`)` + +Within a class body, declares an individual test as a friend of the class, +enabling the test to access private class members. + +If the class is defined in a namespace, then in order to be friends of the +class, test fixtures and tests must be defined in the exact same namespace, +without inline or anonymous namespaces. + +For example, if the class definition looks like the following: + +```cpp +namespace my_namespace { + +class MyClass { + friend class MyClassTest; + FRIEND_TEST(MyClassTest, HasPropertyA); + FRIEND_TEST(MyClassTest, HasPropertyB); + ... definition of class MyClass ... +}; + +} // namespace my_namespace +``` + +Then the test code should look like: + +```cpp +namespace my_namespace { + +class MyClassTest : public ::testing::Test { + ... +}; + +TEST_F(MyClassTest, HasPropertyA) { ... } +TEST_F(MyClassTest, HasPropertyB) { ... } + +} // namespace my_namespace +``` + +See [Testing Private Code](../advanced.md#testing-private-code) for more +information. 
+ +### SCOPED_TRACE {#SCOPED_TRACE} + +`SCOPED_TRACE(`*`message`*`)` + +Causes the current file name, line number, and the given message *`message`* to +be added to the failure message for each assertion failure that occurs in the +scope. + +For more information, see +[Adding Traces to Assertions](../advanced.md#adding-traces-to-assertions). + +See also the [`ScopedTrace` class](#ScopedTrace). + +### GTEST_SKIP {#GTEST_SKIP} + +`GTEST_SKIP()` + +Prevents further test execution at runtime. + +Can be used in individual test cases or in the `SetUp()` methods of test +environments or test fixtures (classes derived from the +[`Environment`](#Environment) or [`Test`](#Test) classes). If used in a global +test environment `SetUp()` method, it skips all tests in the test program. If +used in a test fixture `SetUp()` method, it skips all tests in the corresponding +test suite. + +Similar to assertions, `GTEST_SKIP` allows streaming a custom message into it. + +See [Skipping Test Execution](../advanced.md#skipping-test-execution) for more +information. + +### GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST {#GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST} + +`GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(`*`TestSuiteName`*`)` + +Allows the value-parameterized test suite *`TestSuiteName`* to be +uninstantiated. + +By default, every [`TEST_P`](#TEST_P) call without a corresponding +[`INSTANTIATE_TEST_SUITE_P`](#INSTANTIATE_TEST_SUITE_P) call causes a failing +test in the test suite `GoogleTestVerification`. +`GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST` suppresses this failure for the +given test suite. + +## Classes and types + +GoogleTest defines the following classes and types to help with writing tests. + +### AssertionResult {#AssertionResult} + +`::testing::AssertionResult` + +A class for indicating whether an assertion was successful. + +When the assertion wasn't successful, the `AssertionResult` object stores a +non-empty failure message that can be retrieved with the object's `message()` +method. + +To create an instance of this class, use one of the factory functions +[`AssertionSuccess()`](#AssertionSuccess) or +[`AssertionFailure()`](#AssertionFailure). + +### AssertionException {#AssertionException} + +`::testing::AssertionException` + +Exception which can be thrown from +[`TestEventListener::OnTestPartResult`](#TestEventListener::OnTestPartResult). + +### EmptyTestEventListener {#EmptyTestEventListener} + +`::testing::EmptyTestEventListener` + +Provides an empty implementation of all methods in the +[`TestEventListener`](#TestEventListener) interface, such that a subclass only +needs to override the methods it cares about. + +### Environment {#Environment} + +`::testing::Environment` + +Represents a global test environment. See +[Global Set-Up and Tear-Down](../advanced.md#global-set-up-and-tear-down). + +#### Protected Methods {#Environment-protected} + +##### SetUp {#Environment::SetUp} + +`virtual void Environment::SetUp()` + +Override this to define how to set up the environment. + +##### TearDown {#Environment::TearDown} + +`virtual void Environment::TearDown()` + +Override this to define how to tear down the environment. + +### ScopedTrace {#ScopedTrace} + +`::testing::ScopedTrace` + +An instance of this class causes a trace to be included in every test failure +message generated by code in the scope of the lifetime of the `ScopedTrace` +instance. The effect is undone with the destruction of the instance. 
+ +The `ScopedTrace` constructor has the following form: + +```cpp +template +ScopedTrace(const char* file, int line, const T& message) +``` + +Example usage: + +```cpp +::testing::ScopedTrace trace("file.cc", 123, "message"); +``` + +The resulting trace includes the given source file path and line number, and the +given message. The `message` argument can be anything streamable to +`std::ostream`. + +See also [`SCOPED_TRACE`](#SCOPED_TRACE). + +### Test {#Test} + +`::testing::Test` + +The abstract class that all tests inherit from. `Test` is not copyable. + +#### Public Methods {#Test-public} + +##### SetUpTestSuite {#Test::SetUpTestSuite} + +`static void Test::SetUpTestSuite()` + +Performs shared setup for all tests in the test suite. GoogleTest calls +`SetUpTestSuite()` before running the first test in the test suite. + +##### TearDownTestSuite {#Test::TearDownTestSuite} + +`static void Test::TearDownTestSuite()` + +Performs shared teardown for all tests in the test suite. GoogleTest calls +`TearDownTestSuite()` after running the last test in the test suite. + +##### HasFatalFailure {#Test::HasFatalFailure} + +`static bool Test::HasFatalFailure()` + +Returns true if and only if the current test has a fatal failure. + +##### HasNonfatalFailure {#Test::HasNonfatalFailure} + +`static bool Test::HasNonfatalFailure()` + +Returns true if and only if the current test has a nonfatal failure. + +##### HasFailure {#Test::HasFailure} + +`static bool Test::HasFailure()` + +Returns true if and only if the current test has any failure, either fatal or +nonfatal. + +##### IsSkipped {#Test::IsSkipped} + +`static bool Test::IsSkipped()` + +Returns true if and only if the current test was skipped. + +##### RecordProperty {#Test::RecordProperty} + +`static void Test::RecordProperty(const std::string& key, const std::string& +value)` \ +`static void Test::RecordProperty(const std::string& key, int value)` + +Logs a property for the current test, test suite, or entire invocation of the +test program. Only the last value for a given key is logged. + +The key must be a valid XML attribute name, and cannot conflict with the ones +already used by GoogleTest (`name`, `file`, `line`, `status`, `time`, +`classname`, `type_param`, and `value_param`). + +`RecordProperty` is `public static` so it can be called from utility functions +that are not members of the test fixture. + +Calls to `RecordProperty` made during the lifespan of the test (from the moment +its constructor starts to the moment its destructor finishes) are output in XML +as attributes of the `` element. Properties recorded from a fixture's +`SetUpTestSuite` or `TearDownTestSuite` methods are logged as attributes of the +corresponding `` element. Calls to `RecordProperty` made in the +global context (before or after invocation of `RUN_ALL_TESTS` or from the +`SetUp`/`TearDown` methods of registered `Environment` objects) are output as +attributes of the `` element. + +#### Protected Methods {#Test-protected} + +##### SetUp {#Test::SetUp} + +`virtual void Test::SetUp()` + +Override this to perform test fixture setup. GoogleTest calls `SetUp()` before +running each individual test. + +##### TearDown {#Test::TearDown} + +`virtual void Test::TearDown()` + +Override this to perform test fixture teardown. GoogleTest calls `TearDown()` +after running each individual test. + +### TestWithParam {#TestWithParam} + +`::testing::TestWithParam` + +A convenience class which inherits from both [`Test`](#Test) and +[`WithParamInterface`](#WithParamInterface). 
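+
+As a minimal sketch, a fixture built on `TestWithParam` can combine the
+per-test `SetUp()` hook described above with parameter access via `GetParam()`
+(the `Connection` type and `Connect()` function are hypothetical):
+
+```cpp
+class ConnectionTest : public ::testing::TestWithParam<int> {
+ protected:
+  void SetUp() override {
+    // GetParam() is already valid here and can configure the fixture.
+    connection_ = Connect(/*port=*/GetParam());
+  }
+
+  Connection* connection_ = nullptr;
+};
+
+TEST_P(ConnectionTest, Connects) { EXPECT_NE(connection_, nullptr); }
+
+INSTANTIATE_TEST_SUITE_P(CommonPorts, ConnectionTest,
+                         ::testing::Values(80, 443));
+```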
+ +### TestSuite {#TestSuite} + +Represents a test suite. `TestSuite` is not copyable. + +#### Public Methods {#TestSuite-public} + +##### name {#TestSuite::name} + +`const char* TestSuite::name() const` + +Gets the name of the test suite. + +##### type_param {#TestSuite::type_param} + +`const char* TestSuite::type_param() const` + +Returns the name of the parameter type, or `NULL` if this is not a typed or +type-parameterized test suite. See [Typed Tests](../advanced.md#typed-tests) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests). + +##### should_run {#TestSuite::should_run} + +`bool TestSuite::should_run() const` + +Returns true if any test in this test suite should run. + +##### successful_test_count {#TestSuite::successful_test_count} + +`int TestSuite::successful_test_count() const` + +Gets the number of successful tests in this test suite. + +##### skipped_test_count {#TestSuite::skipped_test_count} + +`int TestSuite::skipped_test_count() const` + +Gets the number of skipped tests in this test suite. + +##### failed_test_count {#TestSuite::failed_test_count} + +`int TestSuite::failed_test_count() const` + +Gets the number of failed tests in this test suite. + +##### reportable_disabled_test_count {#TestSuite::reportable_disabled_test_count} + +`int TestSuite::reportable_disabled_test_count() const` + +Gets the number of disabled tests that will be reported in the XML report. + +##### disabled_test_count {#TestSuite::disabled_test_count} + +`int TestSuite::disabled_test_count() const` + +Gets the number of disabled tests in this test suite. + +##### reportable_test_count {#TestSuite::reportable_test_count} + +`int TestSuite::reportable_test_count() const` + +Gets the number of tests to be printed in the XML report. + +##### test_to_run_count {#TestSuite::test_to_run_count} + +`int TestSuite::test_to_run_count() const` + +Get the number of tests in this test suite that should run. + +##### total_test_count {#TestSuite::total_test_count} + +`int TestSuite::total_test_count() const` + +Gets the number of all tests in this test suite. + +##### Passed {#TestSuite::Passed} + +`bool TestSuite::Passed() const` + +Returns true if and only if the test suite passed. + +##### Failed {#TestSuite::Failed} + +`bool TestSuite::Failed() const` + +Returns true if and only if the test suite failed. + +##### elapsed_time {#TestSuite::elapsed_time} + +`TimeInMillis TestSuite::elapsed_time() const` + +Returns the elapsed time, in milliseconds. + +##### start_timestamp {#TestSuite::start_timestamp} + +`TimeInMillis TestSuite::start_timestamp() const` + +Gets the time of the test suite start, in ms from the start of the UNIX epoch. + +##### GetTestInfo {#TestSuite::GetTestInfo} + +`const TestInfo* TestSuite::GetTestInfo(int i) const` + +Returns the [`TestInfo`](#TestInfo) for the `i`-th test among all the tests. `i` +can range from 0 to `total_test_count() - 1`. If `i` is not in that range, +returns `NULL`. + +##### ad_hoc_test_result {#TestSuite::ad_hoc_test_result} + +`const TestResult& TestSuite::ad_hoc_test_result() const` + +Returns the [`TestResult`](#TestResult) that holds test properties recorded +during execution of `SetUpTestSuite` and `TearDownTestSuite`. + +### TestInfo {#TestInfo} + +`::testing::TestInfo` + +Stores information about a test. + +#### Public Methods {#TestInfo-public} + +##### test_suite_name {#TestInfo::test_suite_name} + +`const char* TestInfo::test_suite_name() const` + +Returns the test suite name. 
+ +##### name {#TestInfo::name} + +`const char* TestInfo::name() const` + +Returns the test name. + +##### type_param {#TestInfo::type_param} + +`const char* TestInfo::type_param() const` + +Returns the name of the parameter type, or `NULL` if this is not a typed or +type-parameterized test. See [Typed Tests](../advanced.md#typed-tests) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests). + +##### value_param {#TestInfo::value_param} + +`const char* TestInfo::value_param() const` + +Returns the text representation of the value parameter, or `NULL` if this is not +a value-parameterized test. See +[Value-Parameterized Tests](../advanced.md#value-parameterized-tests). + +##### file {#TestInfo::file} + +`const char* TestInfo::file() const` + +Returns the file name where this test is defined. + +##### line {#TestInfo::line} + +`int TestInfo::line() const` + +Returns the line where this test is defined. + +##### is_in_another_shard {#TestInfo::is_in_another_shard} + +`bool TestInfo::is_in_another_shard() const` + +Returns true if this test should not be run because it's in another shard. + +##### should_run {#TestInfo::should_run} + +`bool TestInfo::should_run() const` + +Returns true if this test should run, that is if the test is not disabled (or it +is disabled but the `also_run_disabled_tests` flag has been specified) and its +full name matches the user-specified filter. + +GoogleTest allows the user to filter the tests by their full names. Only the +tests that match the filter will run. See +[Running a Subset of the Tests](../advanced.md#running-a-subset-of-the-tests) +for more information. + +##### is_reportable {#TestInfo::is_reportable} + +`bool TestInfo::is_reportable() const` + +Returns true if and only if this test will appear in the XML report. + +##### result {#TestInfo::result} + +`const TestResult* TestInfo::result() const` + +Returns the result of the test. See [`TestResult`](#TestResult). + +### TestParamInfo {#TestParamInfo} + +`::testing::TestParamInfo` + +Describes a parameter to a value-parameterized test. The type `T` is the type of +the parameter. + +Contains the fields `param` and `index` which hold the value of the parameter +and its integer index respectively. + +### UnitTest {#UnitTest} + +`::testing::UnitTest` + +This class contains information about the test program. + +`UnitTest` is a singleton class. The only instance is created when +`UnitTest::GetInstance()` is first called. This instance is never deleted. + +`UnitTest` is not copyable. + +#### Public Methods {#UnitTest-public} + +##### GetInstance {#UnitTest::GetInstance} + +`static UnitTest* UnitTest::GetInstance()` + +Gets the singleton `UnitTest` object. The first time this method is called, a +`UnitTest` object is constructed and returned. Consecutive calls will return the +same object. + +##### original_working_dir {#UnitTest::original_working_dir} + +`const char* UnitTest::original_working_dir() const` + +Returns the working directory when the first [`TEST()`](#TEST) or +[`TEST_F()`](#TEST_F) was executed. The `UnitTest` object owns the string. + +##### current_test_suite {#UnitTest::current_test_suite} + +`const TestSuite* UnitTest::current_test_suite() const` + +Returns the [`TestSuite`](#TestSuite) object for the test that's currently +running, or `NULL` if no test is running. 
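+
+For example, a helper called from inside a test body could query the singleton
+for the suite that is currently running (a sketch; the logging is arbitrary):
+
+```cpp
+// Inside a test body or a helper it calls:
+const ::testing::TestSuite* suite =
+    ::testing::UnitTest::GetInstance()->current_test_suite();
+if (suite != nullptr) {
+  std::cout << "Currently running suite: " << suite->name() << "\n";
+}
+```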
+ +##### current_test_info {#UnitTest::current_test_info} + +`const TestInfo* UnitTest::current_test_info() const` + +Returns the [`TestInfo`](#TestInfo) object for the test that's currently +running, or `NULL` if no test is running. + +##### random_seed {#UnitTest::random_seed} + +`int UnitTest::random_seed() const` + +Returns the random seed used at the start of the current test run. + +##### successful_test_suite_count {#UnitTest::successful_test_suite_count} + +`int UnitTest::successful_test_suite_count() const` + +Gets the number of successful test suites. + +##### failed_test_suite_count {#UnitTest::failed_test_suite_count} + +`int UnitTest::failed_test_suite_count() const` + +Gets the number of failed test suites. + +##### total_test_suite_count {#UnitTest::total_test_suite_count} + +`int UnitTest::total_test_suite_count() const` + +Gets the number of all test suites. + +##### test_suite_to_run_count {#UnitTest::test_suite_to_run_count} + +`int UnitTest::test_suite_to_run_count() const` + +Gets the number of all test suites that contain at least one test that should +run. + +##### successful_test_count {#UnitTest::successful_test_count} + +`int UnitTest::successful_test_count() const` + +Gets the number of successful tests. + +##### skipped_test_count {#UnitTest::skipped_test_count} + +`int UnitTest::skipped_test_count() const` + +Gets the number of skipped tests. + +##### failed_test_count {#UnitTest::failed_test_count} + +`int UnitTest::failed_test_count() const` + +Gets the number of failed tests. + +##### reportable_disabled_test_count {#UnitTest::reportable_disabled_test_count} + +`int UnitTest::reportable_disabled_test_count() const` + +Gets the number of disabled tests that will be reported in the XML report. + +##### disabled_test_count {#UnitTest::disabled_test_count} + +`int UnitTest::disabled_test_count() const` + +Gets the number of disabled tests. + +##### reportable_test_count {#UnitTest::reportable_test_count} + +`int UnitTest::reportable_test_count() const` + +Gets the number of tests to be printed in the XML report. + +##### total_test_count {#UnitTest::total_test_count} + +`int UnitTest::total_test_count() const` + +Gets the number of all tests. + +##### test_to_run_count {#UnitTest::test_to_run_count} + +`int UnitTest::test_to_run_count() const` + +Gets the number of tests that should run. + +##### start_timestamp {#UnitTest::start_timestamp} + +`TimeInMillis UnitTest::start_timestamp() const` + +Gets the time of the test program start, in ms from the start of the UNIX epoch. + +##### elapsed_time {#UnitTest::elapsed_time} + +`TimeInMillis UnitTest::elapsed_time() const` + +Gets the elapsed time, in milliseconds. + +##### Passed {#UnitTest::Passed} + +`bool UnitTest::Passed() const` + +Returns true if and only if the unit test passed (i.e. all test suites passed). + +##### Failed {#UnitTest::Failed} + +`bool UnitTest::Failed() const` + +Returns true if and only if the unit test failed (i.e. some test suite failed or +something outside of all tests failed). + +##### GetTestSuite {#UnitTest::GetTestSuite} + +`const TestSuite* UnitTest::GetTestSuite(int i) const` + +Gets the [`TestSuite`](#TestSuite) object for the `i`-th test suite among all +the test suites. `i` can range from 0 to `total_test_suite_count() - 1`. If `i` +is not in that range, returns `NULL`. 
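+
+For example, `main()` could iterate over every test suite after
+`RUN_ALL_TESTS()` returns and print a one-line summary per suite (a sketch; the
+summary format is arbitrary):
+
+```cpp
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  const int result = RUN_ALL_TESTS();
+
+  const ::testing::UnitTest& unit_test = *::testing::UnitTest::GetInstance();
+  for (int i = 0; i < unit_test.total_test_suite_count(); ++i) {
+    const ::testing::TestSuite* suite = unit_test.GetTestSuite(i);
+    std::cout << suite->name() << ": " << suite->failed_test_count()
+              << " failed out of " << suite->total_test_count() << "\n";
+  }
+  return result;
+}
+```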
+ +##### ad_hoc_test_result {#UnitTest::ad_hoc_test_result} + +`const TestResult& UnitTest::ad_hoc_test_result() const` + +Returns the [`TestResult`](#TestResult) containing information on test failures +and properties logged outside of individual test suites. + +##### listeners {#UnitTest::listeners} + +`TestEventListeners& UnitTest::listeners()` + +Returns the list of event listeners that can be used to track events inside +GoogleTest. See [`TestEventListeners`](#TestEventListeners). + +### TestEventListener {#TestEventListener} + +`::testing::TestEventListener` + +The interface for tracing execution of tests. The methods below are listed in +the order the corresponding events are fired. + +#### Public Methods {#TestEventListener-public} + +##### OnTestProgramStart {#TestEventListener::OnTestProgramStart} + +`virtual void TestEventListener::OnTestProgramStart(const UnitTest& unit_test)` + +Fired before any test activity starts. + +##### OnTestIterationStart {#TestEventListener::OnTestIterationStart} + +`virtual void TestEventListener::OnTestIterationStart(const UnitTest& unit_test, +int iteration)` + +Fired before each iteration of tests starts. There may be more than one +iteration if `GTEST_FLAG(repeat)` is set. `iteration` is the iteration index, +starting from 0. + +##### OnEnvironmentsSetUpStart {#TestEventListener::OnEnvironmentsSetUpStart} + +`virtual void TestEventListener::OnEnvironmentsSetUpStart(const UnitTest& +unit_test)` + +Fired before environment set-up for each iteration of tests starts. + +##### OnEnvironmentsSetUpEnd {#TestEventListener::OnEnvironmentsSetUpEnd} + +`virtual void TestEventListener::OnEnvironmentsSetUpEnd(const UnitTest& +unit_test)` + +Fired after environment set-up for each iteration of tests ends. + +##### OnTestSuiteStart {#TestEventListener::OnTestSuiteStart} + +`virtual void TestEventListener::OnTestSuiteStart(const TestSuite& test_suite)` + +Fired before the test suite starts. + +##### OnTestStart {#TestEventListener::OnTestStart} + +`virtual void TestEventListener::OnTestStart(const TestInfo& test_info)` + +Fired before the test starts. + +##### OnTestPartResult {#TestEventListener::OnTestPartResult} + +`virtual void TestEventListener::OnTestPartResult(const TestPartResult& +test_part_result)` + +Fired after a failed assertion or a `SUCCEED()` invocation. If you want to throw +an exception from this function to skip to the next test, it must be an +[`AssertionException`](#AssertionException) or inherited from it. + +##### OnTestEnd {#TestEventListener::OnTestEnd} + +`virtual void TestEventListener::OnTestEnd(const TestInfo& test_info)` + +Fired after the test ends. + +##### OnTestSuiteEnd {#TestEventListener::OnTestSuiteEnd} + +`virtual void TestEventListener::OnTestSuiteEnd(const TestSuite& test_suite)` + +Fired after the test suite ends. + +##### OnEnvironmentsTearDownStart {#TestEventListener::OnEnvironmentsTearDownStart} + +`virtual void TestEventListener::OnEnvironmentsTearDownStart(const UnitTest& +unit_test)` + +Fired before environment tear-down for each iteration of tests starts. + +##### OnEnvironmentsTearDownEnd {#TestEventListener::OnEnvironmentsTearDownEnd} + +`virtual void TestEventListener::OnEnvironmentsTearDownEnd(const UnitTest& +unit_test)` + +Fired after environment tear-down for each iteration of tests ends. + +##### OnTestIterationEnd {#TestEventListener::OnTestIterationEnd} + +`virtual void TestEventListener::OnTestIterationEnd(const UnitTest& unit_test, +int iteration)` + +Fired after each iteration of tests finishes. 
+ +##### OnTestProgramEnd {#TestEventListener::OnTestProgramEnd} + +`virtual void TestEventListener::OnTestProgramEnd(const UnitTest& unit_test)` + +Fired after all test activities have ended. + +### TestEventListeners {#TestEventListeners} + +`::testing::TestEventListeners` + +Lets users add listeners to track events in GoogleTest. + +#### Public Methods {#TestEventListeners-public} + +##### Append {#TestEventListeners::Append} + +`void TestEventListeners::Append(TestEventListener* listener)` + +Appends an event listener to the end of the list. GoogleTest assumes ownership +of the listener (i.e. it will delete the listener when the test program +finishes). + +##### Release {#TestEventListeners::Release} + +`TestEventListener* TestEventListeners::Release(TestEventListener* listener)` + +Removes the given event listener from the list and returns it. It then becomes +the caller's responsibility to delete the listener. Returns `NULL` if the +listener is not found in the list. + +##### default_result_printer {#TestEventListeners::default_result_printer} + +`TestEventListener* TestEventListeners::default_result_printer() const` + +Returns the standard listener responsible for the default console output. Can be +removed from the listeners list to shut down default console output. Note that +removing this object from the listener list with +[`Release()`](#TestEventListeners::Release) transfers its ownership to the +caller and makes this function return `NULL` the next time. + +##### default_xml_generator {#TestEventListeners::default_xml_generator} + +`TestEventListener* TestEventListeners::default_xml_generator() const` + +Returns the standard listener responsible for the default XML output controlled +by the `--gtest_output=xml` flag. Can be removed from the listeners list by +users who want to shut down the default XML output controlled by this flag and +substitute it with custom one. Note that removing this object from the listener +list with [`Release()`](#TestEventListeners::Release) transfers its ownership to +the caller and makes this function return `NULL` the next time. + +### TestPartResult {#TestPartResult} + +`::testing::TestPartResult` + +A copyable object representing the result of a test part (i.e. an assertion or +an explicit `FAIL()`, `ADD_FAILURE()`, or `SUCCESS()`). + +#### Public Methods {#TestPartResult-public} + +##### type {#TestPartResult::type} + +`Type TestPartResult::type() const` + +Gets the outcome of the test part. + +The return type `Type` is an enum defined as follows: + +```cpp +enum Type { + kSuccess, // Succeeded. + kNonFatalFailure, // Failed but the test can continue. + kFatalFailure, // Failed and the test should be terminated. + kSkip // Skipped. +}; +``` + +##### file_name {#TestPartResult::file_name} + +`const char* TestPartResult::file_name() const` + +Gets the name of the source file where the test part took place, or `NULL` if +it's unknown. + +##### line_number {#TestPartResult::line_number} + +`int TestPartResult::line_number() const` + +Gets the line in the source file where the test part took place, or `-1` if it's +unknown. + +##### summary {#TestPartResult::summary} + +`const char* TestPartResult::summary() const` + +Gets the summary of the failure message. + +##### message {#TestPartResult::message} + +`const char* TestPartResult::message() const` + +Gets the message associated with the test part. + +##### skipped {#TestPartResult::skipped} + +`bool TestPartResult::skipped() const` + +Returns true if and only if the test part was skipped. 
+ +##### passed {#TestPartResult::passed} + +`bool TestPartResult::passed() const` + +Returns true if and only if the test part passed. + +##### nonfatally_failed {#TestPartResult::nonfatally_failed} + +`bool TestPartResult::nonfatally_failed() const` + +Returns true if and only if the test part non-fatally failed. + +##### fatally_failed {#TestPartResult::fatally_failed} + +`bool TestPartResult::fatally_failed() const` + +Returns true if and only if the test part fatally failed. + +##### failed {#TestPartResult::failed} + +`bool TestPartResult::failed() const` + +Returns true if and only if the test part failed. + +### TestProperty {#TestProperty} + +`::testing::TestProperty` + +A copyable object representing a user-specified test property which can be +output as a key/value string pair. + +#### Public Methods {#TestProperty-public} + +##### key {#key} + +`const char* key() const` + +Gets the user-supplied key. + +##### value {#value} + +`const char* value() const` + +Gets the user-supplied value. + +##### SetValue {#SetValue} + +`void SetValue(const std::string& new_value)` + +Sets a new value, overriding the previous one. + +### TestResult {#TestResult} + +`::testing::TestResult` + +Contains information about the result of a single test. + +`TestResult` is not copyable. + +#### Public Methods {#TestResult-public} + +##### total_part_count {#TestResult::total_part_count} + +`int TestResult::total_part_count() const` + +Gets the number of all test parts. This is the sum of the number of successful +test parts and the number of failed test parts. + +##### test_property_count {#TestResult::test_property_count} + +`int TestResult::test_property_count() const` + +Returns the number of test properties. + +##### Passed {#TestResult::Passed} + +`bool TestResult::Passed() const` + +Returns true if and only if the test passed (i.e. no test part failed). + +##### Skipped {#TestResult::Skipped} + +`bool TestResult::Skipped() const` + +Returns true if and only if the test was skipped. + +##### Failed {#TestResult::Failed} + +`bool TestResult::Failed() const` + +Returns true if and only if the test failed. + +##### HasFatalFailure {#TestResult::HasFatalFailure} + +`bool TestResult::HasFatalFailure() const` + +Returns true if and only if the test fatally failed. + +##### HasNonfatalFailure {#TestResult::HasNonfatalFailure} + +`bool TestResult::HasNonfatalFailure() const` + +Returns true if and only if the test has a non-fatal failure. + +##### elapsed_time {#TestResult::elapsed_time} + +`TimeInMillis TestResult::elapsed_time() const` + +Returns the elapsed time, in milliseconds. + +##### start_timestamp {#TestResult::start_timestamp} + +`TimeInMillis TestResult::start_timestamp() const` + +Gets the time of the test case start, in ms from the start of the UNIX epoch. + +##### GetTestPartResult {#TestResult::GetTestPartResult} + +`const TestPartResult& TestResult::GetTestPartResult(int i) const` + +Returns the [`TestPartResult`](#TestPartResult) for the `i`-th test part result +among all the results. `i` can range from 0 to `total_part_count() - 1`. If `i` +is not in that range, aborts the program. + +##### GetTestProperty {#TestResult::GetTestProperty} + +`const TestProperty& TestResult::GetTestProperty(int i) const` + +Returns the [`TestProperty`](#TestProperty) object for the `i`-th test property. +`i` can range from 0 to `test_property_count() - 1`. If `i` is not in that +range, aborts the program. 
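+
+Tying the listener and result types above together, a custom listener can be
+registered to report each test's outcome (a sketch; the `SummaryPrinter` class
+name is arbitrary):
+
+```cpp
+class SummaryPrinter : public ::testing::EmptyTestEventListener {
+  // Called by GoogleTest after each test finishes.
+  void OnTestEnd(const ::testing::TestInfo& test_info) override {
+    const ::testing::TestResult* result = test_info.result();
+    const char* outcome = result->Passed()    ? "OK"
+                          : result->Skipped() ? "SKIPPED"
+                                              : "FAILED";
+    std::cout << test_info.test_suite_name() << "." << test_info.name() << ": "
+              << outcome << " in " << result->elapsed_time() << " ms\n";
+  }
+};
+
+// In main(), after InitGoogleTest() and before RUN_ALL_TESTS(); the listeners
+// list takes ownership of the new object.
+::testing::UnitTest::GetInstance()->listeners().Append(new SummaryPrinter);
+```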
+ +### TimeInMillis {#TimeInMillis} + +`::testing::TimeInMillis` + +An integer type representing time in milliseconds. + +### Types {#Types} + +`::testing::Types` + +Represents a list of types for use in typed tests and type-parameterized tests. + +The template argument `T...` can be any number of types, for example: + +``` +::testing::Types +``` + +See [Typed Tests](../advanced.md#typed-tests) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more +information. + +### WithParamInterface {#WithParamInterface} + +`::testing::WithParamInterface` + +The pure interface class that all value-parameterized tests inherit from. + +A value-parameterized test fixture class must inherit from both [`Test`](#Test) +and `WithParamInterface`. In most cases that just means inheriting from +[`TestWithParam`](#TestWithParam), but more complicated test hierarchies may +need to inherit from `Test` and `WithParamInterface` at different levels. + +This interface defines the type alias `ParamType` for the parameter type `T` and +has support for accessing the test parameter value via the `GetParam()` method: + +``` +static const ParamType& GetParam() +``` + +For more information, see +[Value-Parameterized Tests](../advanced.md#value-parameterized-tests). + +## Functions + +GoogleTest defines the following functions to help with writing and running +tests. + +### InitGoogleTest {#InitGoogleTest} + +`void ::testing::InitGoogleTest(int* argc, char** argv)` \ +`void ::testing::InitGoogleTest(int* argc, wchar_t** argv)` \ +`void ::testing::InitGoogleTest()` + +Initializes GoogleTest. This must be called before calling +[`RUN_ALL_TESTS()`](#RUN_ALL_TESTS). In particular, it parses the command line +for the flags that GoogleTest recognizes. Whenever a GoogleTest flag is seen, it +is removed from `argv`, and `*argc` is decremented. + +No value is returned. Instead, the GoogleTest flag variables are updated. + +The `InitGoogleTest(int* argc, wchar_t** argv)` overload can be used in Windows +programs compiled in `UNICODE` mode. + +The argument-less `InitGoogleTest()` overload can be used on Arduino/embedded +platforms where there is no `argc`/`argv`. + +### AddGlobalTestEnvironment {#AddGlobalTestEnvironment} + +`Environment* ::testing::AddGlobalTestEnvironment(Environment* env)` + +Adds a test environment to the test program. Must be called before +[`RUN_ALL_TESTS()`](#RUN_ALL_TESTS) is called. See +[Global Set-Up and Tear-Down](../advanced.md#global-set-up-and-tear-down) for +more information. + +See also [`Environment`](#Environment). + +### RegisterTest {#RegisterTest} + +```cpp +template +TestInfo* ::testing::RegisterTest(const char* test_suite_name, const char* test_name, + const char* type_param, const char* value_param, + const char* file, int line, Factory factory) +``` + +Dynamically registers a test with the framework. + +The `factory` argument is a factory callable (move-constructible) object or +function pointer that creates a new instance of the `Test` object. It handles +ownership to the caller. The signature of the callable is `Fixture*()`, where +`Fixture` is the test fixture class for the test. All tests registered with the +same `test_suite_name` must return the same fixture type. This is checked at +runtime. + +The framework will infer the fixture class from the factory and will call the +`SetUpTestSuite` and `TearDownTestSuite` methods for it. + +Must be called before [`RUN_ALL_TESTS()`](#RUN_ALL_TESTS) is invoked, otherwise +behavior is undefined. 
+ +See +[Registering tests programmatically](../advanced.md#registering-tests-programmatically) +for more information. + +### RUN_ALL_TESTS {#RUN_ALL_TESTS} + +`int RUN_ALL_TESTS()` + +Use this function in `main()` to run all tests. It returns `0` if all tests are +successful, or `1` otherwise. + +`RUN_ALL_TESTS()` should be invoked after the command line has been parsed by +[`InitGoogleTest()`](#InitGoogleTest). + +This function was formerly a macro; thus, it is in the global namespace and has +an all-caps name. + +### AssertionSuccess {#AssertionSuccess} + +`AssertionResult ::testing::AssertionSuccess()` + +Creates a successful assertion result. See +[`AssertionResult`](#AssertionResult). + +### AssertionFailure {#AssertionFailure} + +`AssertionResult ::testing::AssertionFailure()` + +Creates a failed assertion result. Use the `<<` operator to store a failure +message: + +```cpp +::testing::AssertionFailure() << "My failure message"; +``` + +See [`AssertionResult`](#AssertionResult). + +### StaticAssertTypeEq {#StaticAssertTypeEq} + +`::testing::StaticAssertTypeEq()` + +Compile-time assertion for type equality. Compiles if and only if `T1` and `T2` +are the same type. The value it returns is irrelevant. + +See [Type Assertions](../advanced.md#type-assertions) for more information. + +### PrintToString {#PrintToString} + +`std::string ::testing::PrintToString(x)` + +Prints any value `x` using GoogleTest's value printer. + +See +[Teaching GoogleTest How to Print Your Values](../advanced.md#teaching-googletest-how-to-print-your-values) +for more information. + +### PrintToStringParamName {#PrintToStringParamName} + +`std::string ::testing::PrintToStringParamName(TestParamInfo& info)` + +A built-in parameterized test name generator which returns the result of +[`PrintToString`](#PrintToString) called on `info.param`. Does not work when the +test parameter is a `std::string` or C string. See +[Specifying Names for Value-Parameterized Test Parameters](../advanced.md#specifying-names-for-value-parameterized-test-parameters) +for more information. + +See also [`TestParamInfo`](#TestParamInfo) and +[`INSTANTIATE_TEST_SUITE_P`](#INSTANTIATE_TEST_SUITE_P). diff --git a/3rdparty/googletest-1.13.0/docs/samples.md b/3rdparty/googletest-1.13.0/docs/samples.md new file mode 100644 index 0000000000000000000000000000000000000000..dedc59098df5ae6a318bae5992694dc11b6daf62 --- /dev/null +++ b/3rdparty/googletest-1.13.0/docs/samples.md @@ -0,0 +1,22 @@ +# Googletest Samples + +If you're like us, you'd like to look at +[googletest samples.](https://github.com/google/googletest/blob/main/googletest/samples) +The sample directory has a number of well-commented samples showing how to use a +variety of googletest features. + +* Sample #1 shows the basic steps of using googletest to test C++ functions. +* Sample #2 shows a more complex unit test for a class with multiple member + functions. +* Sample #3 uses a test fixture. +* Sample #4 teaches you how to use googletest and `googletest.h` together to + get the best of both libraries. +* Sample #5 puts shared testing logic in a base test fixture, and reuses it in + derived fixtures. +* Sample #6 demonstrates type-parameterized tests. +* Sample #7 teaches the basics of value-parameterized tests. +* Sample #8 shows using `Combine()` in value-parameterized tests. +* Sample #9 shows use of the listener API to modify Google Test's console + output and the use of its reflection API to inspect test results. 
+* Sample #10 shows use of the listener API to implement a primitive memory + leak checker. diff --git a/3rdparty/googletest-1.13.0/googlemock/README.md b/3rdparty/googletest-1.13.0/googlemock/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7da60655dba8b8e91ec66a9a65f97139af03ee9b --- /dev/null +++ b/3rdparty/googletest-1.13.0/googlemock/README.md @@ -0,0 +1,40 @@ +# Googletest Mocking (gMock) Framework + +### Overview + +Google's framework for writing and using C++ mock classes. It can help you +derive better designs of your system and write better tests. + +It is inspired by: + +* [jMock](http://www.jmock.org/) +* [EasyMock](http://www.easymock.org/) +* [Hamcrest](http://code.google.com/p/hamcrest/) + +It is designed with C++'s specifics in mind. + +gMock: + +- Provides a declarative syntax for defining mocks. +- Can define partial (hybrid) mocks, which are a cross of real and mock + objects. +- Handles functions of arbitrary types and overloaded functions. +- Comes with a rich set of matchers for validating function arguments. +- Uses an intuitive syntax for controlling the behavior of a mock. +- Does automatic verification of expectations (no record-and-replay needed). +- Allows arbitrary (partial) ordering constraints on function calls to be + expressed. +- Lets a user extend it by defining new matchers and actions. +- Does not use exceptions. +- Is easy to learn and use. + +Details and examples can be found here: + +* [gMock for Dummies](https://google.github.io/googletest/gmock_for_dummies.html) +* [Legacy gMock FAQ](https://google.github.io/googletest/gmock_faq.html) +* [gMock Cookbook](https://google.github.io/googletest/gmock_cook_book.html) +* [gMock Cheat Sheet](https://google.github.io/googletest/gmock_cheat_sheet.html) + +GoogleMock is a part of +[GoogleTest C++ testing framework](http://github.com/google/googletest/) and a +subject to the same requirements. 
diff --git a/3rdparty/googletest-1.13.0/googlemock/cmake/gmock.pc.in b/3rdparty/googletest-1.13.0/googlemock/cmake/gmock.pc.in new file mode 100644 index 0000000000000000000000000000000000000000..23c67b5c88db4add6d21403b8ecbaf1be5a88813 --- /dev/null +++ b/3rdparty/googletest-1.13.0/googlemock/cmake/gmock.pc.in @@ -0,0 +1,10 @@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: gmock +Description: GoogleMock (without main() function) +Version: @PROJECT_VERSION@ +URL: https://github.com/google/googletest +Requires: gtest = @PROJECT_VERSION@ +Libs: -L${libdir} -lgmock @CMAKE_THREAD_LIBS_INIT@ +Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@ diff --git a/3rdparty/googletest-1.13.0/googlemock/cmake/gmock_main.pc.in b/3rdparty/googletest-1.13.0/googlemock/cmake/gmock_main.pc.in new file mode 100644 index 0000000000000000000000000000000000000000..66ffea7f4431f606c5ca5d87bef505157658244d --- /dev/null +++ b/3rdparty/googletest-1.13.0/googlemock/cmake/gmock_main.pc.in @@ -0,0 +1,10 @@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: gmock_main +Description: GoogleMock (with main() function) +Version: @PROJECT_VERSION@ +URL: https://github.com/google/googletest +Requires: gmock = @PROJECT_VERSION@ +Libs: -L${libdir} -lgmock_main @CMAKE_THREAD_LIBS_INIT@ +Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@ diff --git a/3rdparty/googletest-1.13.0/googlemock/docs/README.md b/3rdparty/googletest-1.13.0/googlemock/docs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1bc57b799cce933c034c31859594ca1b87689aef --- /dev/null +++ b/3rdparty/googletest-1.13.0/googlemock/docs/README.md @@ -0,0 +1,4 @@ +# Content Moved + +We are working on updates to the GoogleTest documentation, which has moved to +the top-level [docs](../../docs) directory. diff --git a/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-actions.h b/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-actions.h new file mode 100644 index 0000000000000000000000000000000000000000..aad07d51cc1a26f0b55d12633c374b45b2eff49d --- /dev/null +++ b/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-actions.h @@ -0,0 +1,2302 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Mock - a framework for writing C++ mock classes. +// +// The ACTION* family of macros can be used in a namespace scope to +// define custom actions easily. The syntax: +// +// ACTION(name) { statements; } +// +// will define an action with the given name that executes the +// statements. The value returned by the statements will be used as +// the return value of the action. Inside the statements, you can +// refer to the K-th (0-based) argument of the mock function by +// 'argK', and refer to its type by 'argK_type'. For example: +// +// ACTION(IncrementArg1) { +// arg1_type temp = arg1; +// return ++(*temp); +// } +// +// allows you to write +// +// ...WillOnce(IncrementArg1()); +// +// You can also refer to the entire argument tuple and its type by +// 'args' and 'args_type', and refer to the mock function type and its +// return type by 'function_type' and 'return_type'. +// +// Note that you don't need to specify the types of the mock function +// arguments. However rest assured that your code is still type-safe: +// you'll get a compiler error if *arg1 doesn't support the ++ +// operator, or if the type of ++(*arg1) isn't compatible with the +// mock function's return type, for example. +// +// Sometimes you'll want to parameterize the action. For that you can use +// another macro: +// +// ACTION_P(name, param_name) { statements; } +// +// For example: +// +// ACTION_P(Add, n) { return arg0 + n; } +// +// will allow you to write: +// +// ...WillOnce(Add(5)); +// +// Note that you don't need to provide the type of the parameter +// either. If you need to reference the type of a parameter named +// 'foo', you can write 'foo_type'. For example, in the body of +// ACTION_P(Add, n) above, you can write 'n_type' to refer to the type +// of 'n'. +// +// We also provide ACTION_P2, ACTION_P3, ..., up to ACTION_P10 to support +// multi-parameter actions. +// +// For the purpose of typing, you can view +// +// ACTION_Pk(Foo, p1, ..., pk) { ... } +// +// as shorthand for +// +// template +// FooActionPk Foo(p1_type p1, ..., pk_type pk) { ... } +// +// In particular, you can provide the template type arguments +// explicitly when invoking Foo(), as in Foo(5, false); +// although usually you can rely on the compiler to infer the types +// for you automatically. You can assign the result of expression +// Foo(p1, ..., pk) to a variable of type FooActionPk. This can be useful when composing actions. +// +// You can also overload actions with different numbers of parameters: +// +// ACTION_P(Plus, a) { ... } +// ACTION_P2(Plus, a, b) { ... } +// +// While it's tempting to always use the ACTION* macros when defining +// a new action, you should also consider implementing ActionInterface +// or using MakePolymorphicAction() instead, especially if you need to +// use the action a lot. 
While these approaches require more work, +// they give you more control on the types of the mock function +// arguments and the action parameters, which in general leads to +// better compiler error messages that pay off in the long run. They +// also allow overloading actions based on parameter types (as opposed +// to just based on the number of parameters). +// +// CAVEAT: +// +// ACTION*() can only be used in a namespace scope as templates cannot be +// declared inside of a local class. +// Users can, however, define any local functors (e.g. a lambda) that +// can be used as actions. +// +// MORE INFORMATION: +// +// To learn more about using these macros, please search for 'ACTION' on +// https://github.com/google/googletest/blob/main/docs/gmock_cook_book.md + +// IWYU pragma: private, include "gmock/gmock.h" +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_ACTIONS_H_ +#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_ACTIONS_H_ + +#ifndef _WIN32_WCE +#include +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/internal/gmock-internal-utils.h" +#include "gmock/internal/gmock-port.h" +#include "gmock/internal/gmock-pp.h" + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + +namespace testing { + +// To implement an action Foo, define: +// 1. a class FooAction that implements the ActionInterface interface, and +// 2. a factory function that creates an Action object from a +// const FooAction*. +// +// The two-level delegation design follows that of Matcher, providing +// consistency for extension developers. It also eases ownership +// management as Action objects can now be copied like plain values. + +namespace internal { + +// BuiltInDefaultValueGetter::Get() returns a +// default-constructed T value. BuiltInDefaultValueGetter::Get() crashes with an error. +// +// This primary template is used when kDefaultConstructible is true. +template +struct BuiltInDefaultValueGetter { + static T Get() { return T(); } +}; +template +struct BuiltInDefaultValueGetter { + static T Get() { + Assert(false, __FILE__, __LINE__, + "Default action undefined for the function return type."); + return internal::Invalid(); + // The above statement will never be reached, but is required in + // order for this function to compile. + } +}; + +// BuiltInDefaultValue::Get() returns the "built-in" default value +// for type T, which is NULL when T is a raw pointer type, 0 when T is +// a numeric type, false when T is bool, or "" when T is string or +// std::string. In addition, in C++11 and above, it turns a +// default-constructed T value if T is default constructible. For any +// other type T, the built-in default T value is undefined, and the +// function will abort the process. +template +class BuiltInDefaultValue { + public: + // This function returns true if and only if type T has a built-in default + // value. + static bool Exists() { return ::std::is_default_constructible::value; } + + static T Get() { + return BuiltInDefaultValueGetter< + T, ::std::is_default_constructible::value>::Get(); + } +}; + +// This partial specialization says that we use the same built-in +// default value for T and const T. +template +class BuiltInDefaultValue { + public: + static bool Exists() { return BuiltInDefaultValue::Exists(); } + static T Get() { return BuiltInDefaultValue::Get(); } +}; + +// This partial specialization defines the default values for pointer +// types. 
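+//
+// A minimal usage sketch (assuming a hypothetical NiceMock with a method
+// MakeFoo() returning Foo*): when no action and no user default value are
+// specified, the built-in default for the return type is used, so the call
+// below returns nullptr instead of aborting.
+//
+//   NiceMock<MockFooFactory> factory;
+//   Foo* foo = factory.MakeFoo();  // No expectation set: returns nullptr.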
+template +class BuiltInDefaultValue { + public: + static bool Exists() { return true; } + static T* Get() { return nullptr; } +}; + +// The following specializations define the default values for +// specific types we care about. +#define GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(type, value) \ + template <> \ + class BuiltInDefaultValue { \ + public: \ + static bool Exists() { return true; } \ + static type Get() { return value; } \ + } + +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(void, ); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(::std::string, ""); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(bool, false); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned char, '\0'); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed char, '\0'); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(char, '\0'); + +// There's no need for a default action for signed wchar_t, as that +// type is the same as wchar_t for gcc, and invalid for MSVC. +// +// There's also no need for a default action for unsigned wchar_t, as +// that type is the same as unsigned int for gcc, and invalid for +// MSVC. +#if GMOCK_WCHAR_T_IS_NATIVE_ +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(wchar_t, 0U); // NOLINT +#endif + +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned short, 0U); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed short, 0); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned int, 0U); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed int, 0); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned long, 0UL); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed long, 0L); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned long long, 0); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed long long, 0); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(float, 0); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(double, 0); + +#undef GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_ + +// Partial implementations of metaprogramming types from the standard library +// not available in C++11. + +template +struct negation + // NOLINTNEXTLINE + : std::integral_constant {}; + +// Base case: with zero predicates the answer is always true. +template +struct conjunction : std::true_type {}; + +// With a single predicate, the answer is that predicate. +template +struct conjunction : P1 {}; + +// With multiple predicates the answer is the first predicate if that is false, +// and we recurse otherwise. +template +struct conjunction + : std::conditional, P1>::type {}; + +template +struct disjunction : std::false_type {}; + +template +struct disjunction : P1 {}; + +template +struct disjunction + // NOLINTNEXTLINE + : std::conditional, P1>::type {}; + +template +using void_t = void; + +// Detects whether an expression of type `From` can be implicitly converted to +// `To` according to [conv]. In C++17, [conv]/3 defines this as follows: +// +// An expression e can be implicitly converted to a type T if and only if +// the declaration T t=e; is well-formed, for some invented temporary +// variable t ([dcl.init]). +// +// [conv]/2 implies we can use function argument passing to detect whether this +// initialization is valid. 
+// +// Note that this is distinct from is_convertible, which requires this be valid: +// +// To test() { +// return declval(); +// } +// +// In particular, is_convertible doesn't give the correct answer when `To` and +// `From` are the same non-moveable type since `declval` will be an rvalue +// reference, defeating the guaranteed copy elision that would otherwise make +// this function work. +// +// REQUIRES: `From` is not cv void. +template +struct is_implicitly_convertible { + private: + // A function that accepts a parameter of type T. This can be called with type + // U successfully only if U is implicitly convertible to T. + template + static void Accept(T); + + // A function that creates a value of type T. + template + static T Make(); + + // An overload be selected when implicit conversion from T to To is possible. + template (Make()))> + static std::true_type TestImplicitConversion(int); + + // A fallback overload selected in all other cases. + template + static std::false_type TestImplicitConversion(...); + + public: + using type = decltype(TestImplicitConversion(0)); + static constexpr bool value = type::value; +}; + +// Like std::invoke_result_t from C++17, but works only for objects with call +// operators (not e.g. member function pointers, which we don't need specific +// support for in OnceAction because std::function deals with them). +template +using call_result_t = decltype(std::declval()(std::declval()...)); + +template +struct is_callable_r_impl : std::false_type {}; + +// Specialize the struct for those template arguments where call_result_t is +// well-formed. When it's not, the generic template above is chosen, resulting +// in std::false_type. +template +struct is_callable_r_impl>, R, F, Args...> + : std::conditional< + std::is_void::value, // + std::true_type, // + is_implicitly_convertible, R>>::type {}; + +// Like std::is_invocable_r from C++17, but works only for objects with call +// operators. See the note on call_result_t. +template +using is_callable_r = is_callable_r_impl; + +// Like std::as_const from C++17. +template +typename std::add_const::type& as_const(T& t) { + return t; +} + +} // namespace internal + +// Specialized for function types below. +template +class OnceAction; + +// An action that can only be used once. +// +// This is accepted by WillOnce, which doesn't require the underlying action to +// be copy-constructible (only move-constructible), and promises to invoke it as +// an rvalue reference. This allows the action to work with move-only types like +// std::move_only_function in a type-safe manner. +// +// For example: +// +// // Assume we have some API that needs to accept a unique pointer to some +// // non-copyable object Foo. +// void AcceptUniquePointer(std::unique_ptr foo); +// +// // We can define an action that provides a Foo to that API. Because It +// // has to give away its unique pointer, it must not be called more than +// // once, so its call operator is &&-qualified. +// struct ProvideFoo { +// std::unique_ptr foo; +// +// void operator()() && { +// AcceptUniquePointer(std::move(Foo)); +// } +// }; +// +// // This action can be used with WillOnce. +// EXPECT_CALL(mock, Call) +// .WillOnce(ProvideFoo{std::make_unique(...)}); +// +// // But a call to WillRepeatedly will fail to compile. This is correct, +// // since the action cannot correctly be used repeatedly. 
+// EXPECT_CALL(mock, Call) +// .WillRepeatedly(ProvideFoo{std::make_unique(...)}); +// +// A less-contrived example would be an action that returns an arbitrary type, +// whose &&-qualified call operator is capable of dealing with move-only types. +template +class OnceAction final { + private: + // True iff we can use the given callable type (or lvalue reference) directly + // via StdFunctionAdaptor. + template + using IsDirectlyCompatible = internal::conjunction< + // It must be possible to capture the callable in StdFunctionAdaptor. + std::is_constructible::type, Callable>, + // The callable must be compatible with our signature. + internal::is_callable_r::type, + Args...>>; + + // True iff we can use the given callable type via StdFunctionAdaptor once we + // ignore incoming arguments. + template + using IsCompatibleAfterIgnoringArguments = internal::conjunction< + // It must be possible to capture the callable in a lambda. + std::is_constructible::type, Callable>, + // The callable must be invocable with zero arguments, returning something + // convertible to Result. + internal::is_callable_r::type>>; + + public: + // Construct from a callable that is directly compatible with our mocked + // signature: it accepts our function type's arguments and returns something + // convertible to our result type. + template ::type>>, + IsDirectlyCompatible> // + ::value, + int>::type = 0> + OnceAction(Callable&& callable) // NOLINT + : function_(StdFunctionAdaptor::type>( + {}, std::forward(callable))) {} + + // As above, but for a callable that ignores the mocked function's arguments. + template ::type>>, + // Exclude callables for which the overload above works. + // We'd rather provide the arguments if possible. + internal::negation>, + IsCompatibleAfterIgnoringArguments>::value, + int>::type = 0> + OnceAction(Callable&& callable) // NOLINT + // Call the constructor above with a callable + // that ignores the input arguments. + : OnceAction(IgnoreIncomingArguments::type>{ + std::forward(callable)}) {} + + // We are naturally copyable because we store only an std::function, but + // semantically we should not be copyable. + OnceAction(const OnceAction&) = delete; + OnceAction& operator=(const OnceAction&) = delete; + OnceAction(OnceAction&&) = default; + + // Invoke the underlying action callable with which we were constructed, + // handing it the supplied arguments. + Result Call(Args... args) && { + return function_(std::forward(args)...); + } + + private: + // An adaptor that wraps a callable that is compatible with our signature and + // being invoked as an rvalue reference so that it can be used as an + // StdFunctionAdaptor. This throws away type safety, but that's fine because + // this is only used by WillOnce, which we know calls at most once. + // + // Once we have something like std::move_only_function from C++23, we can do + // away with this. + template + class StdFunctionAdaptor final { + public: + // A tag indicating that the (otherwise universal) constructor is accepting + // the callable itself, instead of e.g. stealing calls for the move + // constructor. + struct CallableTag final {}; + + template + explicit StdFunctionAdaptor(CallableTag, F&& callable) + : callable_(std::make_shared(std::forward(callable))) {} + + // Rather than explicitly returning Result, we return whatever the wrapped + // callable returns. 
This allows for compatibility with existing uses like + // the following, when the mocked function returns void: + // + // EXPECT_CALL(mock_fn_, Call) + // .WillOnce([&] { + // [...] + // return 0; + // }); + // + // Such a callable can be turned into std::function. If we use an + // explicit return type of Result here then it *doesn't* work with + // std::function, because we'll get a "void function should not return a + // value" error. + // + // We need not worry about incompatible result types because the SFINAE on + // OnceAction already checks this for us. std::is_invocable_r_v itself makes + // the same allowance for void result types. + template + internal::call_result_t operator()( + ArgRefs&&... args) const { + return std::move(*callable_)(std::forward(args)...); + } + + private: + // We must put the callable on the heap so that we are copyable, which + // std::function needs. + std::shared_ptr callable_; + }; + + // An adaptor that makes a callable that accepts zero arguments callable with + // our mocked arguments. + template + struct IgnoreIncomingArguments { + internal::call_result_t operator()(Args&&...) { + return std::move(callable)(); + } + + Callable callable; + }; + + std::function function_; +}; + +// When an unexpected function call is encountered, Google Mock will +// let it return a default value if the user has specified one for its +// return type, or if the return type has a built-in default value; +// otherwise Google Mock won't know what value to return and will have +// to abort the process. +// +// The DefaultValue class allows a user to specify the +// default value for a type T that is both copyable and publicly +// destructible (i.e. anything that can be used as a function return +// type). The usage is: +// +// // Sets the default value for type T to be foo. +// DefaultValue::Set(foo); +template +class DefaultValue { + public: + // Sets the default value for type T; requires T to be + // copy-constructable and have a public destructor. + static void Set(T x) { + delete producer_; + producer_ = new FixedValueProducer(x); + } + + // Provides a factory function to be called to generate the default value. + // This method can be used even if T is only move-constructible, but it is not + // limited to that case. + typedef T (*FactoryFunction)(); + static void SetFactory(FactoryFunction factory) { + delete producer_; + producer_ = new FactoryValueProducer(factory); + } + + // Unsets the default value for type T. + static void Clear() { + delete producer_; + producer_ = nullptr; + } + + // Returns true if and only if the user has set the default value for type T. + static bool IsSet() { return producer_ != nullptr; } + + // Returns true if T has a default return value set by the user or there + // exists a built-in default value. + static bool Exists() { + return IsSet() || internal::BuiltInDefaultValue::Exists(); + } + + // Returns the default value for type T if the user has set one; + // otherwise returns the built-in default value. Requires that Exists() + // is true, which ensures that the return value is well-defined. + static T Get() { + return producer_ == nullptr ? 
internal::BuiltInDefaultValue::Get() + : producer_->Produce(); + } + + private: + class ValueProducer { + public: + virtual ~ValueProducer() {} + virtual T Produce() = 0; + }; + + class FixedValueProducer : public ValueProducer { + public: + explicit FixedValueProducer(T value) : value_(value) {} + T Produce() override { return value_; } + + private: + const T value_; + FixedValueProducer(const FixedValueProducer&) = delete; + FixedValueProducer& operator=(const FixedValueProducer&) = delete; + }; + + class FactoryValueProducer : public ValueProducer { + public: + explicit FactoryValueProducer(FactoryFunction factory) + : factory_(factory) {} + T Produce() override { return factory_(); } + + private: + const FactoryFunction factory_; + FactoryValueProducer(const FactoryValueProducer&) = delete; + FactoryValueProducer& operator=(const FactoryValueProducer&) = delete; + }; + + static ValueProducer* producer_; +}; + +// This partial specialization allows a user to set default values for +// reference types. +template +class DefaultValue { + public: + // Sets the default value for type T&. + static void Set(T& x) { // NOLINT + address_ = &x; + } + + // Unsets the default value for type T&. + static void Clear() { address_ = nullptr; } + + // Returns true if and only if the user has set the default value for type T&. + static bool IsSet() { return address_ != nullptr; } + + // Returns true if T has a default return value set by the user or there + // exists a built-in default value. + static bool Exists() { + return IsSet() || internal::BuiltInDefaultValue::Exists(); + } + + // Returns the default value for type T& if the user has set one; + // otherwise returns the built-in default value if there is one; + // otherwise aborts the process. + static T& Get() { + return address_ == nullptr ? internal::BuiltInDefaultValue::Get() + : *address_; + } + + private: + static T* address_; +}; + +// This specialization allows DefaultValue::Get() to +// compile. +template <> +class DefaultValue { + public: + static bool Exists() { return true; } + static void Get() {} +}; + +// Points to the user-set default value for type T. +template +typename DefaultValue::ValueProducer* DefaultValue::producer_ = nullptr; + +// Points to the user-set default value for type T&. +template +T* DefaultValue::address_ = nullptr; + +// Implement this interface to define an action for function type F. +template +class ActionInterface { + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + ActionInterface() {} + virtual ~ActionInterface() {} + + // Performs the action. This method is not const, as in general an + // action can have side effects and be stateful. For example, a + // get-the-next-element-from-the-collection action will need to + // remember the current element. + virtual Result Perform(const ArgumentTuple& args) = 0; + + private: + ActionInterface(const ActionInterface&) = delete; + ActionInterface& operator=(const ActionInterface&) = delete; +}; + +template +class Action; + +// An Action is a copyable and IMMUTABLE (except by assignment) +// object that represents an action to be taken when a mock function of type +// R(Args...) is called. The implementation of Action is just a +// std::shared_ptr to const ActionInterface. Don't inherit from Action! You +// can view an object implementing ActionInterface as a concrete action +// (including its current state), and an Action object as a handle to it. 
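+//
+// A minimal usage sketch (assuming a mock function signature int(int)):
+// an Action can be constructed from any compatible callable and invoked
+// through Perform() with the arguments packed in a tuple.
+//
+//   Action<int(int)> twice = [](int n) { return 2 * n; };
+//   int result = twice.Perform(std::make_tuple(21));  // result is 42.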
+template +class Action { + private: + using F = R(Args...); + + // Adapter class to allow constructing Action from a legacy ActionInterface. + // New code should create Actions from functors instead. + struct ActionAdapter { + // Adapter must be copyable to satisfy std::function requirements. + ::std::shared_ptr> impl_; + + template + typename internal::Function::Result operator()(InArgs&&... args) { + return impl_->Perform( + ::std::forward_as_tuple(::std::forward(args)...)); + } + }; + + template + using IsCompatibleFunctor = std::is_constructible, G>; + + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + // Constructs a null Action. Needed for storing Action objects in + // STL containers. + Action() {} + + // Construct an Action from a specified callable. + // This cannot take std::function directly, because then Action would not be + // directly constructible from lambda (it would require two conversions). + template < + typename G, + typename = typename std::enable_if, std::is_constructible, + G>>::value>::type> + Action(G&& fun) { // NOLINT + Init(::std::forward(fun), IsCompatibleFunctor()); + } + + // Constructs an Action from its implementation. + explicit Action(ActionInterface* impl) + : fun_(ActionAdapter{::std::shared_ptr>(impl)}) {} + + // This constructor allows us to turn an Action object into an + // Action, as long as F's arguments can be implicitly converted + // to Func's and Func's return type can be implicitly converted to F's. + template + Action(const Action& action) // NOLINT + : fun_(action.fun_) {} + + // Returns true if and only if this is the DoDefault() action. + bool IsDoDefault() const { return fun_ == nullptr; } + + // Performs the action. Note that this method is const even though + // the corresponding method in ActionInterface is not. The reason + // is that a const Action means that it cannot be re-bound to + // another concrete action, not that the concrete action it binds to + // cannot change state. (Think of the difference between a const + // pointer and a pointer to const.) + Result Perform(ArgumentTuple args) const { + if (IsDoDefault()) { + internal::IllegalDoDefault(__FILE__, __LINE__); + } + return internal::Apply(fun_, ::std::move(args)); + } + + // An action can be used as a OnceAction, since it's obviously safe to call it + // once. + operator OnceAction() const { // NOLINT + // Return a OnceAction-compatible callable that calls Perform with the + // arguments it is provided. We could instead just return fun_, but then + // we'd need to handle the IsDoDefault() case separately. + struct OA { + Action action; + + R operator()(Args... args) && { + return action.Perform( + std::forward_as_tuple(std::forward(args)...)); + } + }; + + return OA{*this}; + } + + private: + template + friend class Action; + + template + void Init(G&& g, ::std::true_type) { + fun_ = ::std::forward(g); + } + + template + void Init(G&& g, ::std::false_type) { + fun_ = IgnoreArgs::type>{::std::forward(g)}; + } + + template + struct IgnoreArgs { + template + Result operator()(const InArgs&...) const { + return function_impl(); + } + + FunctionImpl function_impl; + }; + + // fun_ is an empty function if and only if this is the DoDefault() action. + ::std::function fun_; +}; + +// The PolymorphicAction class template makes it easy to implement a +// polymorphic action (i.e. an action that can be used in mock +// functions of than one type, e.g. Return()). 
+// +// To define a polymorphic action, a user first provides a COPYABLE +// implementation class that has a Perform() method template: +// +// class FooAction { +// public: +// template +// Result Perform(const ArgumentTuple& args) const { +// // Processes the arguments and returns a result, using +// // std::get(args) to get the N-th (0-based) argument in the tuple. +// } +// ... +// }; +// +// Then the user creates the polymorphic action using +// MakePolymorphicAction(object) where object has type FooAction. See +// the definition of Return(void) and SetArgumentPointee(value) for +// complete examples. +template +class PolymorphicAction { + public: + explicit PolymorphicAction(const Impl& impl) : impl_(impl) {} + + template + operator Action() const { + return Action(new MonomorphicImpl(impl_)); + } + + private: + template + class MonomorphicImpl : public ActionInterface { + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + explicit MonomorphicImpl(const Impl& impl) : impl_(impl) {} + + Result Perform(const ArgumentTuple& args) override { + return impl_.template Perform(args); + } + + private: + Impl impl_; + }; + + Impl impl_; +}; + +// Creates an Action from its implementation and returns it. The +// created Action object owns the implementation. +template +Action MakeAction(ActionInterface* impl) { + return Action(impl); +} + +// Creates a polymorphic action from its implementation. This is +// easier to use than the PolymorphicAction constructor as it +// doesn't require you to explicitly write the template argument, e.g. +// +// MakePolymorphicAction(foo); +// vs +// PolymorphicAction(foo); +template +inline PolymorphicAction MakePolymorphicAction(const Impl& impl) { + return PolymorphicAction(impl); +} + +namespace internal { + +// Helper struct to specialize ReturnAction to execute a move instead of a copy +// on return. Useful for move-only types, but could be used on any type. +template +struct ByMoveWrapper { + explicit ByMoveWrapper(T value) : payload(std::move(value)) {} + T payload; +}; + +// The general implementation of Return(R). Specializations follow below. +template +class ReturnAction final { + public: + explicit ReturnAction(R value) : value_(std::move(value)) {} + + template >, // + negation>, // + std::is_convertible, // + std::is_move_constructible>::value>::type> + operator OnceAction() && { // NOLINT + return Impl(std::move(value_)); + } + + template >, // + negation>, // + std::is_convertible, // + std::is_copy_constructible>::value>::type> + operator Action() const { // NOLINT + return Impl(value_); + } + + private: + // Implements the Return(x) action for a mock function that returns type U. + template + class Impl final { + public: + // The constructor used when the return value is allowed to move from the + // input value (i.e. we are converting to OnceAction). + explicit Impl(R&& input_value) + : state_(new State(std::move(input_value))) {} + + // The constructor used when the return value is not allowed to move from + // the input value (i.e. we are converting to Action). + explicit Impl(const R& input_value) : state_(new State(input_value)) {} + + U operator()() && { return std::move(state_->value); } + U operator()() const& { return state_->value; } + + private: + // We put our state on the heap so that the compiler-generated copy/move + // constructors work correctly even when U is a reference-like type. 
This is + // necessary only because we eagerly create State::value (see the note on + // that symbol for details). If we instead had only the input value as a + // member then the default constructors would work fine. + // + // For example, when R is std::string and U is std::string_view, value is a + // reference to the string backed by input_value. The copy constructor would + // copy both, so that we wind up with a new input_value object (with the + // same contents) and a reference to the *old* input_value object rather + // than the new one. + struct State { + explicit State(const R& input_value_in) + : input_value(input_value_in), + // Make an implicit conversion to Result before initializing the U + // object we store, avoiding calling any explicit constructor of U + // from R. + // + // This simulates the language rules: a function with return type U + // that does `return R()` requires R to be implicitly convertible to + // U, and uses that path for the conversion, even U Result has an + // explicit constructor from R. + value(ImplicitCast_(internal::as_const(input_value))) {} + + // As above, but for the case where we're moving from the ReturnAction + // object because it's being used as a OnceAction. + explicit State(R&& input_value_in) + : input_value(std::move(input_value_in)), + // For the same reason as above we make an implicit conversion to U + // before initializing the value. + // + // Unlike above we provide the input value as an rvalue to the + // implicit conversion because this is a OnceAction: it's fine if it + // wants to consume the input value. + value(ImplicitCast_(std::move(input_value))) {} + + // A copy of the value originally provided by the user. We retain this in + // addition to the value of the mock function's result type below in case + // the latter is a reference-like type. See the std::string_view example + // in the documentation on Return. + R input_value; + + // The value we actually return, as the type returned by the mock function + // itself. + // + // We eagerly initialize this here, rather than lazily doing the implicit + // conversion automatically each time Perform is called, for historical + // reasons: in 2009-11, commit a070cbd91c (Google changelist 13540126) + // made the Action conversion operator eagerly convert the R value to + // U, but without keeping the R alive. This broke the use case discussed + // in the documentation for Return, making reference-like types such as + // std::string_view not safe to use as U where the input type R is a + // value-like type such as std::string. + // + // The example the commit gave was not very clear, nor was the issue + // thread (https://github.com/google/googlemock/issues/86), but it seems + // the worry was about reference-like input types R that flatten to a + // value-like type U when being implicitly converted. An example of this + // is std::vector::reference, which is often a proxy type with an + // reference to the underlying vector: + // + // // Helper method: have the mock function return bools according + // // to the supplied script. + // void SetActions(MockFunction& mock, + // const std::vector& script) { + // for (size_t i = 0; i < script.size(); ++i) { + // EXPECT_CALL(mock, Call(i)).WillOnce(Return(script[i])); + // } + // } + // + // TEST(Foo, Bar) { + // // Set actions using a temporary vector, whose operator[] + // // returns proxy objects that references that will be + // // dangling once the call to SetActions finishes and the + // // vector is destroyed. 
+ // MockFunction mock; + // SetActions(mock, {false, true}); + // + // EXPECT_FALSE(mock.AsStdFunction()(0)); + // EXPECT_TRUE(mock.AsStdFunction()(1)); + // } + // + // This eager conversion helps with a simple case like this, but doesn't + // fully make these types work in general. For example the following still + // uses a dangling reference: + // + // TEST(Foo, Baz) { + // MockFunction()> mock; + // + // // Return the same vector twice, and then the empty vector + // // thereafter. + // auto action = Return(std::initializer_list{ + // "taco", "burrito", + // }); + // + // EXPECT_CALL(mock, Call) + // .WillOnce(action) + // .WillOnce(action) + // .WillRepeatedly(Return(std::vector{})); + // + // EXPECT_THAT(mock.AsStdFunction()(), + // ElementsAre("taco", "burrito")); + // EXPECT_THAT(mock.AsStdFunction()(), + // ElementsAre("taco", "burrito")); + // EXPECT_THAT(mock.AsStdFunction()(), IsEmpty()); + // } + // + U value; + }; + + const std::shared_ptr state_; + }; + + R value_; +}; + +// A specialization of ReturnAction when R is ByMoveWrapper for some T. +// +// This version applies the type system-defeating hack of moving from T even in +// the const call operator, checking at runtime that it isn't called more than +// once, since the user has declared their intent to do so by using ByMove. +template +class ReturnAction> final { + public: + explicit ReturnAction(ByMoveWrapper wrapper) + : state_(new State(std::move(wrapper.payload))) {} + + T operator()() const { + GTEST_CHECK_(!state_->called) + << "A ByMove() action must be performed at most once."; + + state_->called = true; + return std::move(state_->value); + } + + private: + // We store our state on the heap so that we are copyable as required by + // Action, despite the fact that we are stateful and T may not be copyable. + struct State { + explicit State(T&& value_in) : value(std::move(value_in)) {} + + T value; + bool called = false; + }; + + const std::shared_ptr state_; +}; + +// Implements the ReturnNull() action. +class ReturnNullAction { + public: + // Allows ReturnNull() to be used in any pointer-returning function. In C++11 + // this is enforced by returning nullptr, and in non-C++11 by asserting a + // pointer type on compile time. + template + static Result Perform(const ArgumentTuple&) { + return nullptr; + } +}; + +// Implements the Return() action. +class ReturnVoidAction { + public: + // Allows Return() to be used in any void-returning function. + template + static void Perform(const ArgumentTuple&) { + static_assert(std::is_void::value, "Result should be void."); + } +}; + +// Implements the polymorphic ReturnRef(x) action, which can be used +// in any function that returns a reference to the type of x, +// regardless of the argument types. +template +class ReturnRefAction { + public: + // Constructs a ReturnRefAction object from the reference to be returned. + explicit ReturnRefAction(T& ref) : ref_(ref) {} // NOLINT + + // This template type conversion operator allows ReturnRef(x) to be + // used in ANY function that returns a reference to x's type. + template + operator Action() const { + typedef typename Function::Result Result; + // Asserts that the function return type is a reference. This + // catches the user error of using ReturnRef(x) when Return(x) + // should be used, and generates some helpful error message. 
+ static_assert(std::is_reference::value, + "use Return instead of ReturnRef to return a value"); + return Action(new Impl(ref_)); + } + + private: + // Implements the ReturnRef(x) action for a particular function type F. + template + class Impl : public ActionInterface { + public: + typedef typename Function::Result Result; + typedef typename Function::ArgumentTuple ArgumentTuple; + + explicit Impl(T& ref) : ref_(ref) {} // NOLINT + + Result Perform(const ArgumentTuple&) override { return ref_; } + + private: + T& ref_; + }; + + T& ref_; +}; + +// Implements the polymorphic ReturnRefOfCopy(x) action, which can be +// used in any function that returns a reference to the type of x, +// regardless of the argument types. +template +class ReturnRefOfCopyAction { + public: + // Constructs a ReturnRefOfCopyAction object from the reference to + // be returned. + explicit ReturnRefOfCopyAction(const T& value) : value_(value) {} // NOLINT + + // This template type conversion operator allows ReturnRefOfCopy(x) to be + // used in ANY function that returns a reference to x's type. + template + operator Action() const { + typedef typename Function::Result Result; + // Asserts that the function return type is a reference. This + // catches the user error of using ReturnRefOfCopy(x) when Return(x) + // should be used, and generates some helpful error message. + static_assert(std::is_reference::value, + "use Return instead of ReturnRefOfCopy to return a value"); + return Action(new Impl(value_)); + } + + private: + // Implements the ReturnRefOfCopy(x) action for a particular function type F. + template + class Impl : public ActionInterface { + public: + typedef typename Function::Result Result; + typedef typename Function::ArgumentTuple ArgumentTuple; + + explicit Impl(const T& value) : value_(value) {} // NOLINT + + Result Perform(const ArgumentTuple&) override { return value_; } + + private: + T value_; + }; + + const T value_; +}; + +// Implements the polymorphic ReturnRoundRobin(v) action, which can be +// used in any function that returns the element_type of v. +template +class ReturnRoundRobinAction { + public: + explicit ReturnRoundRobinAction(std::vector values) { + GTEST_CHECK_(!values.empty()) + << "ReturnRoundRobin requires at least one element."; + state_->values = std::move(values); + } + + template + T operator()(Args&&...) const { + return state_->Next(); + } + + private: + struct State { + T Next() { + T ret_val = values[i++]; + if (i == values.size()) i = 0; + return ret_val; + } + + std::vector values; + size_t i = 0; + }; + std::shared_ptr state_ = std::make_shared(); +}; + +// Implements the polymorphic DoDefault() action. +class DoDefaultAction { + public: + // This template type conversion operator allows DoDefault() to be + // used in any function. + template + operator Action() const { + return Action(); + } // NOLINT +}; + +// Implements the Assign action to set a given pointer referent to a +// particular value. +template +class AssignAction { + public: + AssignAction(T1* ptr, T2 value) : ptr_(ptr), value_(value) {} + + template + void Perform(const ArgumentTuple& /* args */) const { + *ptr_ = value_; + } + + private: + T1* const ptr_; + const T2 value_; +}; + +#if !GTEST_OS_WINDOWS_MOBILE + +// Implements the SetErrnoAndReturn action to simulate return from +// various system calls and libc functions. 
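+//
+// A usage sketch (assuming a hypothetical mock method Read() that mimics
+// the libc call): simulate a failure by setting errno to EAGAIN and
+// returning -1.
+//
+//   EXPECT_CALL(mock_os, Read(_, _, _))
+//       .WillOnce(SetErrnoAndReturn(EAGAIN, -1));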
+template +class SetErrnoAndReturnAction { + public: + SetErrnoAndReturnAction(int errno_value, T result) + : errno_(errno_value), result_(result) {} + template + Result Perform(const ArgumentTuple& /* args */) const { + errno = errno_; + return result_; + } + + private: + const int errno_; + const T result_; +}; + +#endif // !GTEST_OS_WINDOWS_MOBILE + +// Implements the SetArgumentPointee(x) action for any function +// whose N-th argument (0-based) is a pointer to x's type. +template +struct SetArgumentPointeeAction { + A value; + + template + void operator()(const Args&... args) const { + *::std::get(std::tie(args...)) = value; + } +}; + +// Implements the Invoke(object_ptr, &Class::Method) action. +template +struct InvokeMethodAction { + Class* const obj_ptr; + const MethodPtr method_ptr; + + template + auto operator()(Args&&... args) const + -> decltype((obj_ptr->*method_ptr)(std::forward(args)...)) { + return (obj_ptr->*method_ptr)(std::forward(args)...); + } +}; + +// Implements the InvokeWithoutArgs(f) action. The template argument +// FunctionImpl is the implementation type of f, which can be either a +// function pointer or a functor. InvokeWithoutArgs(f) can be used as an +// Action as long as f's type is compatible with F. +template +struct InvokeWithoutArgsAction { + FunctionImpl function_impl; + + // Allows InvokeWithoutArgs(f) to be used as any action whose type is + // compatible with f. + template + auto operator()(const Args&...) -> decltype(function_impl()) { + return function_impl(); + } +}; + +// Implements the InvokeWithoutArgs(object_ptr, &Class::Method) action. +template +struct InvokeMethodWithoutArgsAction { + Class* const obj_ptr; + const MethodPtr method_ptr; + + using ReturnType = + decltype((std::declval()->*std::declval())()); + + template + ReturnType operator()(const Args&...) const { + return (obj_ptr->*method_ptr)(); + } +}; + +// Implements the IgnoreResult(action) action. +template +class IgnoreResultAction { + public: + explicit IgnoreResultAction(const A& action) : action_(action) {} + + template + operator Action() const { + // Assert statement belongs here because this is the best place to verify + // conditions on F. It produces the clearest error messages + // in most compilers. + // Impl really belongs in this scope as a local class but can't + // because MSVC produces duplicate symbols in different translation units + // in this case. Until MS fixes that bug we put Impl into the class scope + // and put the typedef both here (for use in assert statement) and + // in the Impl class. But both definitions must be the same. + typedef typename internal::Function::Result Result; + + // Asserts at compile time that F returns void. + static_assert(std::is_void::value, "Result type should be void."); + + return Action(new Impl(action_)); + } + + private: + template + class Impl : public ActionInterface { + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + explicit Impl(const A& action) : action_(action) {} + + void Perform(const ArgumentTuple& args) override { + // Performs the action and ignores its result. + action_.Perform(args); + } + + private: + // Type OriginalFunction is the same as F except that its return + // type is IgnoredValue. 
+ typedef + typename internal::Function::MakeResultIgnoredValue OriginalFunction; + + const Action action_; + }; + + const A action_; +}; + +template +struct WithArgsAction { + InnerAction inner_action; + + // The signature of the function as seen by the inner action, given an out + // action with the given result and argument types. + template + using InnerSignature = + R(typename std::tuple_element>::type...); + + // Rather than a call operator, we must define conversion operators to + // particular action types. This is necessary for embedded actions like + // DoDefault(), which rely on an action conversion operators rather than + // providing a call operator because even with a particular set of arguments + // they don't have a fixed return type. + + template >...)>>::value, + int>::type = 0> + operator OnceAction() && { // NOLINT + struct OA { + OnceAction> inner_action; + + R operator()(Args&&... args) && { + return std::move(inner_action) + .Call(std::get( + std::forward_as_tuple(std::forward(args)...))...); + } + }; + + return OA{std::move(inner_action)}; + } + + template >...)>>::value, + int>::type = 0> + operator Action() const { // NOLINT + Action> converted(inner_action); + + return [converted](Args&&... args) -> R { + return converted.Perform(std::forward_as_tuple( + std::get(std::forward_as_tuple(std::forward(args)...))...)); + }; + } +}; + +template +class DoAllAction; + +// Base case: only a single action. +template +class DoAllAction { + public: + struct UserConstructorTag {}; + + template + explicit DoAllAction(UserConstructorTag, T&& action) + : final_action_(std::forward(action)) {} + + // Rather than a call operator, we must define conversion operators to + // particular action types. This is necessary for embedded actions like + // DoDefault(), which rely on an action conversion operators rather than + // providing a call operator because even with a particular set of arguments + // they don't have a fixed return type. + + template >::value, + int>::type = 0> + operator OnceAction() && { // NOLINT + return std::move(final_action_); + } + + template < + typename R, typename... Args, + typename std::enable_if< + std::is_convertible>::value, + int>::type = 0> + operator Action() const { // NOLINT + return final_action_; + } + + private: + FinalAction final_action_; +}; + +// Recursive case: support N actions by calling the initial action and then +// calling through to the base class containing N-1 actions. +template +class DoAllAction + : private DoAllAction { + private: + using Base = DoAllAction; + + // The type of reference that should be provided to an initial action for a + // mocked function parameter of type T. + // + // There are two quirks here: + // + // * Unlike most forwarding functions, we pass scalars through by value. + // This isn't strictly necessary because an lvalue reference would work + // fine too and be consistent with other non-reference types, but it's + // perhaps less surprising. + // + // For example if the mocked function has signature void(int), then it + // might seem surprising for the user's initial action to need to be + // convertible to Action. This is perhaps less + // surprising for a non-scalar type where there may be a performance + // impact, or it might even be impossible, to pass by value. + // + // * More surprisingly, `const T&` is often not a const reference type. + // By the reference collapsing rules in C++17 [dcl.ref]/6, if T refers to + // U& or U&& for some non-scalar type U, then InitialActionArgType is + // U&. 
In other words, we may hand over a non-const reference. + // + // So for example, given some non-scalar type Obj we have the following + // mappings: + // + // T InitialActionArgType + // ------- ----------------------- + // Obj const Obj& + // Obj& Obj& + // Obj&& Obj& + // const Obj const Obj& + // const Obj& const Obj& + // const Obj&& const Obj& + // + // In other words, the initial actions get a mutable view of an non-scalar + // argument if and only if the mock function itself accepts a non-const + // reference type. They are never given an rvalue reference to an + // non-scalar type. + // + // This situation makes sense if you imagine use with a matcher that is + // designed to write through a reference. For example, if the caller wants + // to fill in a reference argument and then return a canned value: + // + // EXPECT_CALL(mock, Call) + // .WillOnce(DoAll(SetArgReferee<0>(17), Return(19))); + // + template + using InitialActionArgType = + typename std::conditional::value, T, const T&>::type; + + public: + struct UserConstructorTag {}; + + template + explicit DoAllAction(UserConstructorTag, T&& initial_action, + U&&... other_actions) + : Base({}, std::forward(other_actions)...), + initial_action_(std::forward(initial_action)) {} + + template ...)>>, + std::is_convertible>>::value, + int>::type = 0> + operator OnceAction() && { // NOLINT + // Return an action that first calls the initial action with arguments + // filtered through InitialActionArgType, then forwards arguments directly + // to the base class to deal with the remaining actions. + struct OA { + OnceAction...)> initial_action; + OnceAction remaining_actions; + + R operator()(Args... args) && { + std::move(initial_action) + .Call(static_cast>(args)...); + + return std::move(remaining_actions).Call(std::forward(args)...); + } + }; + + return OA{ + std::move(initial_action_), + std::move(static_cast(*this)), + }; + } + + template < + typename R, typename... Args, + typename std::enable_if< + conjunction< + // Both the initial action and the rest must support conversion to + // Action. + std::is_convertible...)>>, + std::is_convertible>>::value, + int>::type = 0> + operator Action() const { // NOLINT + // Return an action that first calls the initial action with arguments + // filtered through InitialActionArgType, then forwards arguments directly + // to the base class to deal with the remaining actions. + struct OA { + Action...)> initial_action; + Action remaining_actions; + + R operator()(Args... args) const { + initial_action.Perform(std::forward_as_tuple( + static_cast>(args)...)); + + return remaining_actions.Perform( + std::forward_as_tuple(std::forward(args)...)); + } + }; + + return OA{ + initial_action_, + static_cast(*this), + }; + } + + private: + InitialAction initial_action_; +}; + +template +struct ReturnNewAction { + T* operator()() const { + return internal::Apply( + [](const Params&... unpacked_params) { + return new T(unpacked_params...); + }, + params); + } + std::tuple params; +}; + +template +struct ReturnArgAction { + template ::type> + auto operator()(Args&&... args) const -> decltype(std::get( + std::forward_as_tuple(std::forward(args)...))) { + return std::get(std::forward_as_tuple(std::forward(args)...)); + } +}; + +template +struct SaveArgAction { + Ptr pointer; + + template + void operator()(const Args&... args) const { + *pointer = std::get(std::tie(args...)); + } +}; + +template +struct SaveArgPointeeAction { + Ptr pointer; + + template + void operator()(const Args&... 
args) const { + *pointer = *std::get(std::tie(args...)); + } +}; + +template +struct SetArgRefereeAction { + T value; + + template + void operator()(Args&&... args) const { + using argk_type = + typename ::std::tuple_element>::type; + static_assert(std::is_lvalue_reference::value, + "Argument must be a reference type."); + std::get(std::tie(args...)) = value; + } +}; + +template +struct SetArrayArgumentAction { + I1 first; + I2 last; + + template + void operator()(const Args&... args) const { + auto value = std::get(std::tie(args...)); + for (auto it = first; it != last; ++it, (void)++value) { + *value = *it; + } + } +}; + +template +struct DeleteArgAction { + template + void operator()(const Args&... args) const { + delete std::get(std::tie(args...)); + } +}; + +template +struct ReturnPointeeAction { + Ptr pointer; + template + auto operator()(const Args&...) const -> decltype(*pointer) { + return *pointer; + } +}; + +#if GTEST_HAS_EXCEPTIONS +template +struct ThrowAction { + T exception; + // We use a conversion operator to adapt to any return type. + template + operator Action() const { // NOLINT + T copy = exception; + return [copy](Args...) -> R { throw copy; }; + } +}; +#endif // GTEST_HAS_EXCEPTIONS + +} // namespace internal + +// An Unused object can be implicitly constructed from ANY value. +// This is handy when defining actions that ignore some or all of the +// mock function arguments. For example, given +// +// MOCK_METHOD3(Foo, double(const string& label, double x, double y)); +// MOCK_METHOD3(Bar, double(int index, double x, double y)); +// +// instead of +// +// double DistanceToOriginWithLabel(const string& label, double x, double y) { +// return sqrt(x*x + y*y); +// } +// double DistanceToOriginWithIndex(int index, double x, double y) { +// return sqrt(x*x + y*y); +// } +// ... +// EXPECT_CALL(mock, Foo("abc", _, _)) +// .WillOnce(Invoke(DistanceToOriginWithLabel)); +// EXPECT_CALL(mock, Bar(5, _, _)) +// .WillOnce(Invoke(DistanceToOriginWithIndex)); +// +// you could write +// +// // We can declare any uninteresting argument as Unused. +// double DistanceToOrigin(Unused, double x, double y) { +// return sqrt(x*x + y*y); +// } +// ... +// EXPECT_CALL(mock, Foo("abc", _, _)).WillOnce(Invoke(DistanceToOrigin)); +// EXPECT_CALL(mock, Bar(5, _, _)).WillOnce(Invoke(DistanceToOrigin)); +typedef internal::IgnoredValue Unused; + +// Creates an action that does actions a1, a2, ..., sequentially in +// each invocation. All but the last action will have a readonly view of the +// arguments. +template +internal::DoAllAction::type...> DoAll( + Action&&... action) { + return internal::DoAllAction::type...>( + {}, std::forward(action)...); +} + +// WithArg(an_action) creates an action that passes the k-th +// (0-based) argument of the mock function to an_action and performs +// it. It adapts an action accepting one argument to one that accepts +// multiple arguments. For convenience, we also provide +// WithArgs(an_action) (defined below) as a synonym. +template +internal::WithArgsAction::type, k> WithArg( + InnerAction&& action) { + return {std::forward(action)}; +} + +// WithArgs(an_action) creates an action that passes +// the selected arguments of the mock function to an_action and +// performs it. It serves as an adaptor between actions with +// different argument lists. 
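+//
+// A usage sketch (assuming a hypothetical mock method Foo(bool, int, int)
+// returning int, and a binary helper int Sum(int, int)): forward only the
+// second and third arguments to the helper.
+//
+//   EXPECT_CALL(mock, Foo(_, _, _))
+//       .WillOnce(WithArgs<1, 2>(Invoke(Sum)));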
+template +internal::WithArgsAction::type, k, ks...> +WithArgs(InnerAction&& action) { + return {std::forward(action)}; +} + +// WithoutArgs(inner_action) can be used in a mock function with a +// non-empty argument list to perform inner_action, which takes no +// argument. In other words, it adapts an action accepting no +// argument to one that accepts (and ignores) arguments. +template +internal::WithArgsAction::type> WithoutArgs( + InnerAction&& action) { + return {std::forward(action)}; +} + +// Creates an action that returns a value. +// +// The returned type can be used with a mock function returning a non-void, +// non-reference type U as follows: +// +// * If R is convertible to U and U is move-constructible, then the action can +// be used with WillOnce. +// +// * If const R& is convertible to U and U is copy-constructible, then the +// action can be used with both WillOnce and WillRepeatedly. +// +// The mock expectation contains the R value from which the U return value is +// constructed (a move/copy of the argument to Return). This means that the R +// value will survive at least until the mock object's expectations are cleared +// or the mock object is destroyed, meaning that U can safely be a +// reference-like type such as std::string_view: +// +// // The mock function returns a view of a copy of the string fed to +// // Return. The view is valid even after the action is performed. +// MockFunction mock; +// EXPECT_CALL(mock, Call).WillOnce(Return(std::string("taco"))); +// const std::string_view result = mock.AsStdFunction()(); +// EXPECT_EQ("taco", result); +// +template +internal::ReturnAction Return(R value) { + return internal::ReturnAction(std::move(value)); +} + +// Creates an action that returns NULL. +inline PolymorphicAction ReturnNull() { + return MakePolymorphicAction(internal::ReturnNullAction()); +} + +// Creates an action that returns from a void function. +inline PolymorphicAction Return() { + return MakePolymorphicAction(internal::ReturnVoidAction()); +} + +// Creates an action that returns the reference to a variable. +template +inline internal::ReturnRefAction ReturnRef(R& x) { // NOLINT + return internal::ReturnRefAction(x); +} + +// Prevent using ReturnRef on reference to temporary. +template +internal::ReturnRefAction ReturnRef(R&&) = delete; + +// Creates an action that returns the reference to a copy of the +// argument. The copy is created when the action is constructed and +// lives as long as the action. +template +inline internal::ReturnRefOfCopyAction ReturnRefOfCopy(const R& x) { + return internal::ReturnRefOfCopyAction(x); +} + +// DEPRECATED: use Return(x) directly with WillOnce. +// +// Modifies the parent action (a Return() action) to perform a move of the +// argument instead of a copy. +// Return(ByMove()) actions can only be executed once and will assert this +// invariant. +template +internal::ByMoveWrapper ByMove(R x) { + return internal::ByMoveWrapper(std::move(x)); +} + +// Creates an action that returns an element of `vals`. Calling this action will +// repeatedly return the next value from `vals` until it reaches the end and +// will restart from the beginning. +template +internal::ReturnRoundRobinAction ReturnRoundRobin(std::vector vals) { + return internal::ReturnRoundRobinAction(std::move(vals)); +} + +// Creates an action that returns an element of `vals`. Calling this action will +// repeatedly return the next value from `vals` until it reaches the end and +// will restart from the beginning. 
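+//
+// A usage sketch (assuming a MockFunction<int()> named mock): successive
+// calls cycle through the supplied values 1, 2, 3, 1, 2, ...
+//
+//   MockFunction<int()> mock;
+//   EXPECT_CALL(mock, Call())
+//       .WillRepeatedly(ReturnRoundRobin({1, 2, 3}));
+//   EXPECT_EQ(1, mock.AsStdFunction()());
+//   EXPECT_EQ(2, mock.AsStdFunction()());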
+template +internal::ReturnRoundRobinAction ReturnRoundRobin( + std::initializer_list vals) { + return internal::ReturnRoundRobinAction(std::vector(vals)); +} + +// Creates an action that does the default action for the give mock function. +inline internal::DoDefaultAction DoDefault() { + return internal::DoDefaultAction(); +} + +// Creates an action that sets the variable pointed by the N-th +// (0-based) function argument to 'value'. +template +internal::SetArgumentPointeeAction SetArgPointee(T value) { + return {std::move(value)}; +} + +// The following version is DEPRECATED. +template +internal::SetArgumentPointeeAction SetArgumentPointee(T value) { + return {std::move(value)}; +} + +// Creates an action that sets a pointer referent to a given value. +template +PolymorphicAction> Assign(T1* ptr, T2 val) { + return MakePolymorphicAction(internal::AssignAction(ptr, val)); +} + +#if !GTEST_OS_WINDOWS_MOBILE + +// Creates an action that sets errno and returns the appropriate error. +template +PolymorphicAction> SetErrnoAndReturn( + int errval, T result) { + return MakePolymorphicAction( + internal::SetErrnoAndReturnAction(errval, result)); +} + +#endif // !GTEST_OS_WINDOWS_MOBILE + +// Various overloads for Invoke(). + +// Legacy function. +// Actions can now be implicitly constructed from callables. No need to create +// wrapper objects. +// This function exists for backwards compatibility. +template +typename std::decay::type Invoke(FunctionImpl&& function_impl) { + return std::forward(function_impl); +} + +// Creates an action that invokes the given method on the given object +// with the mock function's arguments. +template +internal::InvokeMethodAction Invoke(Class* obj_ptr, + MethodPtr method_ptr) { + return {obj_ptr, method_ptr}; +} + +// Creates an action that invokes 'function_impl' with no argument. +template +internal::InvokeWithoutArgsAction::type> +InvokeWithoutArgs(FunctionImpl function_impl) { + return {std::move(function_impl)}; +} + +// Creates an action that invokes the given method on the given object +// with no argument. +template +internal::InvokeMethodWithoutArgsAction InvokeWithoutArgs( + Class* obj_ptr, MethodPtr method_ptr) { + return {obj_ptr, method_ptr}; +} + +// Creates an action that performs an_action and throws away its +// result. In other words, it changes the return type of an_action to +// void. an_action MUST NOT return void, or the code won't compile. +template +inline internal::IgnoreResultAction IgnoreResult(const A& an_action) { + return internal::IgnoreResultAction(an_action); +} + +// Creates a reference wrapper for the given L-value. If necessary, +// you can explicitly specify the type of the reference. For example, +// suppose 'derived' is an object of type Derived, ByRef(derived) +// would wrap a Derived&. If you want to wrap a const Base& instead, +// where Base is a base class of Derived, just write: +// +// ByRef(derived) +// +// N.B. ByRef is redundant with std::ref, std::cref and std::reference_wrapper. +// However, it may still be used for consistency with ByMove(). +template +inline ::std::reference_wrapper ByRef(T& l_value) { // NOLINT + return ::std::reference_wrapper(l_value); +} + +// The ReturnNew(a1, a2, ..., a_k) action returns a pointer to a new +// instance of type T, constructed on the heap with constructor arguments +// a1, a2, ..., and a_k. The caller assumes ownership of the returned value. +template +internal::ReturnNewAction::type...> ReturnNew( + Params&&... 
params) { + return {std::forward_as_tuple(std::forward(params)...)}; +} + +// Action ReturnArg() returns the k-th argument of the mock function. +template +internal::ReturnArgAction ReturnArg() { + return {}; +} + +// Action SaveArg(pointer) saves the k-th (0-based) argument of the +// mock function to *pointer. +template +internal::SaveArgAction SaveArg(Ptr pointer) { + return {pointer}; +} + +// Action SaveArgPointee(pointer) saves the value pointed to +// by the k-th (0-based) argument of the mock function to *pointer. +template +internal::SaveArgPointeeAction SaveArgPointee(Ptr pointer) { + return {pointer}; +} + +// Action SetArgReferee(value) assigns 'value' to the variable +// referenced by the k-th (0-based) argument of the mock function. +template +internal::SetArgRefereeAction::type> SetArgReferee( + T&& value) { + return {std::forward(value)}; +} + +// Action SetArrayArgument(first, last) copies the elements in +// source range [first, last) to the array pointed to by the k-th +// (0-based) argument, which can be either a pointer or an +// iterator. The action does not take ownership of the elements in the +// source range. +template +internal::SetArrayArgumentAction SetArrayArgument(I1 first, + I2 last) { + return {first, last}; +} + +// Action DeleteArg() deletes the k-th (0-based) argument of the mock +// function. +template +internal::DeleteArgAction DeleteArg() { + return {}; +} + +// This action returns the value pointed to by 'pointer'. +template +internal::ReturnPointeeAction ReturnPointee(Ptr pointer) { + return {pointer}; +} + +// Action Throw(exception) can be used in a mock function of any type +// to throw the given exception. Any copyable value can be thrown. +#if GTEST_HAS_EXCEPTIONS +template +internal::ThrowAction::type> Throw(T&& exception) { + return {std::forward(exception)}; +} +#endif // GTEST_HAS_EXCEPTIONS + +namespace internal { + +// A macro from the ACTION* family (defined later in gmock-generated-actions.h) +// defines an action that can be used in a mock function. Typically, +// these actions only care about a subset of the arguments of the mock +// function. For example, if such an action only uses the second +// argument, it can be used in any mock function that takes >= 2 +// arguments where the type of the second argument is compatible. +// +// Therefore, the action implementation must be prepared to take more +// arguments than it needs. The ExcessiveArg type is used to +// represent those excessive arguments. In order to keep the compiler +// error messages tractable, we define it in the testing namespace +// instead of testing::internal. However, this is an INTERNAL TYPE +// and subject to change without notice, so a user MUST NOT USE THIS +// TYPE DIRECTLY. +struct ExcessiveArg {}; + +// Builds an implementation of an Action<> for some particular signature, using +// a class defined by an ACTION* macro. +template +struct ActionImpl; + +template +struct ImplBase { + struct Holder { + // Allows each copy of the Action<> to get to the Impl. + explicit operator const Impl&() const { return *ptr; } + std::shared_ptr ptr; + }; + using type = typename std::conditional::value, + Impl, Holder>::type; +}; + +template +struct ActionImpl : ImplBase::type { + using Base = typename ImplBase::type; + using function_type = R(Args...); + using args_type = std::tuple; + + ActionImpl() = default; // Only defined if appropriate for Base. + explicit ActionImpl(std::shared_ptr impl) : Base{std::move(impl)} {} + + R operator()(Args&&... 
arg) const { + static constexpr size_t kMaxArgs = + sizeof...(Args) <= 10 ? sizeof...(Args) : 10; + return Apply(MakeIndexSequence{}, + MakeIndexSequence<10 - kMaxArgs>{}, + args_type{std::forward(arg)...}); + } + + template + R Apply(IndexSequence, IndexSequence, + const args_type& args) const { + // Impl need not be specific to the signature of action being implemented; + // only the implementing function body needs to have all of the specific + // types instantiated. Up to 10 of the args that are provided by the + // args_type get passed, followed by a dummy of unspecified type for the + // remainder up to 10 explicit args. + static constexpr ExcessiveArg kExcessArg{}; + return static_cast(*this) + .template gmock_PerformImpl< + /*function_type=*/function_type, /*return_type=*/R, + /*args_type=*/args_type, + /*argN_type=*/ + typename std::tuple_element::type...>( + /*args=*/args, std::get(args)..., + ((void)excess_id, kExcessArg)...); + } +}; + +// Stores a default-constructed Impl as part of the Action<>'s +// std::function<>. The Impl should be trivial to copy. +template +::testing::Action MakeAction() { + return ::testing::Action(ActionImpl()); +} + +// Stores just the one given instance of Impl. +template +::testing::Action MakeAction(std::shared_ptr impl) { + return ::testing::Action(ActionImpl(std::move(impl))); +} + +#define GMOCK_INTERNAL_ARG_UNUSED(i, data, el) \ + , const arg##i##_type& arg##i GTEST_ATTRIBUTE_UNUSED_ +#define GMOCK_ACTION_ARG_TYPES_AND_NAMES_UNUSED_ \ + const args_type& args GTEST_ATTRIBUTE_UNUSED_ GMOCK_PP_REPEAT( \ + GMOCK_INTERNAL_ARG_UNUSED, , 10) + +#define GMOCK_INTERNAL_ARG(i, data, el) , const arg##i##_type& arg##i +#define GMOCK_ACTION_ARG_TYPES_AND_NAMES_ \ + const args_type& args GMOCK_PP_REPEAT(GMOCK_INTERNAL_ARG, , 10) + +#define GMOCK_INTERNAL_TEMPLATE_ARG(i, data, el) , typename arg##i##_type +#define GMOCK_ACTION_TEMPLATE_ARGS_NAMES_ \ + GMOCK_PP_TAIL(GMOCK_PP_REPEAT(GMOCK_INTERNAL_TEMPLATE_ARG, , 10)) + +#define GMOCK_INTERNAL_TYPENAME_PARAM(i, data, param) , typename param##_type +#define GMOCK_ACTION_TYPENAME_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_TYPENAME_PARAM, , params)) + +#define GMOCK_INTERNAL_TYPE_PARAM(i, data, param) , param##_type +#define GMOCK_ACTION_TYPE_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_TYPE_PARAM, , params)) + +#define GMOCK_INTERNAL_TYPE_GVALUE_PARAM(i, data, param) \ + , param##_type gmock_p##i +#define GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_TYPE_GVALUE_PARAM, , params)) + +#define GMOCK_INTERNAL_GVALUE_PARAM(i, data, param) \ + , std::forward(gmock_p##i) +#define GMOCK_ACTION_GVALUE_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_GVALUE_PARAM, , params)) + +#define GMOCK_INTERNAL_INIT_PARAM(i, data, param) \ + , param(::std::forward(gmock_p##i)) +#define GMOCK_ACTION_INIT_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_INIT_PARAM, , params)) + +#define GMOCK_INTERNAL_FIELD_PARAM(i, data, param) param##_type param; +#define GMOCK_ACTION_FIELD_PARAMS_(params) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_FIELD_PARAM, , params) + +#define GMOCK_INTERNAL_ACTION(name, full_name, params) \ + template \ + class full_name { \ + public: \ + explicit full_name(GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) \ + : impl_(std::make_shared( \ + GMOCK_ACTION_GVALUE_PARAMS_(params))) {} \ + full_name(const full_name&) = default; \ + full_name(full_name&&) noexcept = default; \ + template \ + operator 
::testing::Action() const { \ + return ::testing::internal::MakeAction(impl_); \ + } \ + \ + private: \ + class gmock_Impl { \ + public: \ + explicit gmock_Impl(GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) \ + : GMOCK_ACTION_INIT_PARAMS_(params) {} \ + template \ + return_type gmock_PerformImpl(GMOCK_ACTION_ARG_TYPES_AND_NAMES_) const; \ + GMOCK_ACTION_FIELD_PARAMS_(params) \ + }; \ + std::shared_ptr impl_; \ + }; \ + template \ + inline full_name name( \ + GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) GTEST_MUST_USE_RESULT_; \ + template \ + inline full_name name( \ + GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) { \ + return full_name( \ + GMOCK_ACTION_GVALUE_PARAMS_(params)); \ + } \ + template \ + template \ + return_type \ + full_name::gmock_Impl::gmock_PerformImpl( \ + GMOCK_ACTION_ARG_TYPES_AND_NAMES_UNUSED_) const + +} // namespace internal + +// Similar to GMOCK_INTERNAL_ACTION, but no bound parameters are stored. +#define ACTION(name) \ + class name##Action { \ + public: \ + explicit name##Action() noexcept {} \ + name##Action(const name##Action&) noexcept {} \ + template \ + operator ::testing::Action() const { \ + return ::testing::internal::MakeAction(); \ + } \ + \ + private: \ + class gmock_Impl { \ + public: \ + template \ + return_type gmock_PerformImpl(GMOCK_ACTION_ARG_TYPES_AND_NAMES_) const; \ + }; \ + }; \ + inline name##Action name() GTEST_MUST_USE_RESULT_; \ + inline name##Action name() { return name##Action(); } \ + template \ + return_type name##Action::gmock_Impl::gmock_PerformImpl( \ + GMOCK_ACTION_ARG_TYPES_AND_NAMES_UNUSED_) const + +#define ACTION_P(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP, (__VA_ARGS__)) + +#define ACTION_P2(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP2, (__VA_ARGS__)) + +#define ACTION_P3(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP3, (__VA_ARGS__)) + +#define ACTION_P4(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP4, (__VA_ARGS__)) + +#define ACTION_P5(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP5, (__VA_ARGS__)) + +#define ACTION_P6(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP6, (__VA_ARGS__)) + +#define ACTION_P7(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP7, (__VA_ARGS__)) + +#define ACTION_P8(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP8, (__VA_ARGS__)) + +#define ACTION_P9(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP9, (__VA_ARGS__)) + +#define ACTION_P10(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP10, (__VA_ARGS__)) + +} // namespace testing + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +#endif // GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_ACTIONS_H_ diff --git a/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-cardinalities.h b/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-cardinalities.h new file mode 100644 index 0000000000000000000000000000000000000000..b6ab648e50a649257120e62fdc404e8e5ba2c1d9 --- /dev/null +++ b/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-cardinalities.h @@ -0,0 +1,159 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Mock - a framework for writing C++ mock classes. +// +// This file implements some commonly used cardinalities. More +// cardinalities can be defined by the user implementing the +// CardinalityInterface interface if necessary. + +// IWYU pragma: private, include "gmock/gmock.h" +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_CARDINALITIES_H_ +#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_CARDINALITIES_H_ + +#include + +#include +#include // NOLINT + +#include "gmock/internal/gmock-port.h" +#include "gtest/gtest.h" + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ +/* class A needs to have dll-interface to be used by clients of class B */) + +namespace testing { + +// To implement a cardinality Foo, define: +// 1. a class FooCardinality that implements the +// CardinalityInterface interface, and +// 2. a factory function that creates a Cardinality object from a +// const FooCardinality*. +// +// The two-level delegation design follows that of Matcher, providing +// consistency for extension developers. It also eases ownership +// management as Cardinality objects can now be copied like plain values. + +// The implementation of a cardinality. +class CardinalityInterface { + public: + virtual ~CardinalityInterface() {} + + // Conservative estimate on the lower/upper bound of the number of + // calls allowed. + virtual int ConservativeLowerBound() const { return 0; } + virtual int ConservativeUpperBound() const { return INT_MAX; } + + // Returns true if and only if call_count calls will satisfy this + // cardinality. + virtual bool IsSatisfiedByCallCount(int call_count) const = 0; + + // Returns true if and only if call_count calls will saturate this + // cardinality. + virtual bool IsSaturatedByCallCount(int call_count) const = 0; + + // Describes self to an ostream. + virtual void DescribeTo(::std::ostream* os) const = 0; +}; + +// A Cardinality is a copyable and IMMUTABLE (except by assignment) +// object that specifies how many times a mock function is expected to +// be called. The implementation of Cardinality is just a std::shared_ptr +// to const CardinalityInterface. Don't inherit from Cardinality! +class GTEST_API_ Cardinality { + public: + // Constructs a null cardinality. Needed for storing Cardinality + // objects in STL containers. 
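+  //
+  // Illustrative sketch (mock_foo, Bar and Baz are hypothetical): user code
+  // normally obtains a Cardinality from one of the factory functions declared
+  // later in this header and passes it to Times(), e.g.
+  //
+  //   EXPECT_CALL(mock_foo, Bar(_)).Times(AtLeast(2));
+  //   EXPECT_CALL(mock_foo, Baz()).Times(Between(1, 3));
+  //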
+ Cardinality() {} + + // Constructs a Cardinality from its implementation. + explicit Cardinality(const CardinalityInterface* impl) : impl_(impl) {} + + // Conservative estimate on the lower/upper bound of the number of + // calls allowed. + int ConservativeLowerBound() const { return impl_->ConservativeLowerBound(); } + int ConservativeUpperBound() const { return impl_->ConservativeUpperBound(); } + + // Returns true if and only if call_count calls will satisfy this + // cardinality. + bool IsSatisfiedByCallCount(int call_count) const { + return impl_->IsSatisfiedByCallCount(call_count); + } + + // Returns true if and only if call_count calls will saturate this + // cardinality. + bool IsSaturatedByCallCount(int call_count) const { + return impl_->IsSaturatedByCallCount(call_count); + } + + // Returns true if and only if call_count calls will over-saturate this + // cardinality, i.e. exceed the maximum number of allowed calls. + bool IsOverSaturatedByCallCount(int call_count) const { + return impl_->IsSaturatedByCallCount(call_count) && + !impl_->IsSatisfiedByCallCount(call_count); + } + + // Describes self to an ostream + void DescribeTo(::std::ostream* os) const { impl_->DescribeTo(os); } + + // Describes the given actual call count to an ostream. + static void DescribeActualCallCountTo(int actual_call_count, + ::std::ostream* os); + + private: + std::shared_ptr impl_; +}; + +// Creates a cardinality that allows at least n calls. +GTEST_API_ Cardinality AtLeast(int n); + +// Creates a cardinality that allows at most n calls. +GTEST_API_ Cardinality AtMost(int n); + +// Creates a cardinality that allows any number of calls. +GTEST_API_ Cardinality AnyNumber(); + +// Creates a cardinality that allows between min and max calls. +GTEST_API_ Cardinality Between(int min, int max); + +// Creates a cardinality that allows exactly n calls. +GTEST_API_ Cardinality Exactly(int n); + +// Creates a cardinality from its implementation. +inline Cardinality MakeCardinality(const CardinalityInterface* c) { + return Cardinality(c); +} + +} // namespace testing + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 + +#endif // GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_CARDINALITIES_H_ diff --git a/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-function-mocker.h b/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-function-mocker.h new file mode 100644 index 0000000000000000000000000000000000000000..73065493b38b4fbe8bedaaeb690779769708cac1 --- /dev/null +++ b/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-function-mocker.h @@ -0,0 +1,517 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Mock - a framework for writing C++ mock classes. +// +// This file implements MOCK_METHOD. + +// IWYU pragma: private, include "gmock/gmock.h" +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_FUNCTION_MOCKER_H_ +#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_FUNCTION_MOCKER_H_ + +#include // IWYU pragma: keep +#include // IWYU pragma: keep + +#include "gmock/gmock-spec-builders.h" +#include "gmock/internal/gmock-internal-utils.h" +#include "gmock/internal/gmock-pp.h" + +namespace testing { +namespace internal { +template +using identity_t = T; + +template +struct ThisRefAdjuster { + template + using AdjustT = typename std::conditional< + std::is_const::type>::value, + typename std::conditional::value, + const T&, const T&&>::type, + typename std::conditional::value, T&, + T&&>::type>::type; + + template + static AdjustT Adjust(const MockType& mock) { + return static_cast>(const_cast(mock)); + } +}; + +constexpr bool PrefixOf(const char* a, const char* b) { + return *a == 0 || (*a == *b && internal::PrefixOf(a + 1, b + 1)); +} + +template +constexpr bool StartsWith(const char (&prefix)[N], const char (&str)[M]) { + return N <= M && internal::PrefixOf(prefix, str); +} + +template +constexpr bool EndsWith(const char (&suffix)[N], const char (&str)[M]) { + return N <= M && internal::PrefixOf(suffix, str + M - N); +} + +template +constexpr bool Equals(const char (&a)[N], const char (&b)[M]) { + return N == M && internal::PrefixOf(a, b); +} + +template +constexpr bool ValidateSpec(const char (&spec)[N]) { + return internal::Equals("const", spec) || + internal::Equals("override", spec) || + internal::Equals("final", spec) || + internal::Equals("noexcept", spec) || + (internal::StartsWith("noexcept(", spec) && + internal::EndsWith(")", spec)) || + internal::Equals("ref(&)", spec) || + internal::Equals("ref(&&)", spec) || + (internal::StartsWith("Calltype(", spec) && + internal::EndsWith(")", spec)); +} + +} // namespace internal + +// The style guide prohibits "using" statements in a namespace scope +// inside a header file. However, the FunctionMocker class template +// is meant to be defined in the ::testing namespace. The following +// line is just a trick for working around a bug in MSVC 8.0, which +// cannot handle it if we define FunctionMocker in ::testing. +using internal::FunctionMocker; +} // namespace testing + +#define MOCK_METHOD(...) \ + GMOCK_INTERNAL_WARNING_PUSH() \ + GMOCK_INTERNAL_WARNING_CLANG(ignored, "-Wunused-member-function") \ + GMOCK_PP_VARIADIC_CALL(GMOCK_INTERNAL_MOCK_METHOD_ARG_, __VA_ARGS__) \ + GMOCK_INTERNAL_WARNING_POP() + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_1(...) 
\ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_2(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_3(_Ret, _MethodName, _Args) \ + GMOCK_INTERNAL_MOCK_METHOD_ARG_4(_Ret, _MethodName, _Args, ()) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_4(_Ret, _MethodName, _Args, _Spec) \ + GMOCK_INTERNAL_ASSERT_PARENTHESIS(_Args); \ + GMOCK_INTERNAL_ASSERT_PARENTHESIS(_Spec); \ + GMOCK_INTERNAL_ASSERT_VALID_SIGNATURE( \ + GMOCK_PP_NARG0 _Args, GMOCK_INTERNAL_SIGNATURE(_Ret, _Args)); \ + GMOCK_INTERNAL_ASSERT_VALID_SPEC(_Spec) \ + GMOCK_INTERNAL_MOCK_METHOD_IMPL( \ + GMOCK_PP_NARG0 _Args, _MethodName, GMOCK_INTERNAL_HAS_CONST(_Spec), \ + GMOCK_INTERNAL_HAS_OVERRIDE(_Spec), GMOCK_INTERNAL_HAS_FINAL(_Spec), \ + GMOCK_INTERNAL_GET_NOEXCEPT_SPEC(_Spec), \ + GMOCK_INTERNAL_GET_CALLTYPE_SPEC(_Spec), \ + GMOCK_INTERNAL_GET_REF_SPEC(_Spec), \ + (GMOCK_INTERNAL_SIGNATURE(_Ret, _Args))) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_5(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_6(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_7(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_WRONG_ARITY(...) \ + static_assert( \ + false, \ + "MOCK_METHOD must be called with 3 or 4 arguments. _Ret, " \ + "_MethodName, _Args and optionally _Spec. _Args and _Spec must be " \ + "enclosed in parentheses. If _Ret is a type with unprotected commas, " \ + "it must also be enclosed in parentheses.") + +#define GMOCK_INTERNAL_ASSERT_PARENTHESIS(_Tuple) \ + static_assert( \ + GMOCK_PP_IS_ENCLOSED_PARENS(_Tuple), \ + GMOCK_PP_STRINGIZE(_Tuple) " should be enclosed in parentheses.") + +#define GMOCK_INTERNAL_ASSERT_VALID_SIGNATURE(_N, ...) \ + static_assert( \ + std::is_function<__VA_ARGS__>::value, \ + "Signature must be a function type, maybe return type contains " \ + "unprotected comma."); \ + static_assert( \ + ::testing::tuple_size::ArgumentTuple>::value == _N, \ + "This method does not take " GMOCK_PP_STRINGIZE( \ + _N) " arguments. 
Parenthesize all types with unprotected commas.") + +#define GMOCK_INTERNAL_ASSERT_VALID_SPEC(_Spec) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_ASSERT_VALID_SPEC_ELEMENT, ~, _Spec) + +#define GMOCK_INTERNAL_MOCK_METHOD_IMPL(_N, _MethodName, _Constness, \ + _Override, _Final, _NoexceptSpec, \ + _CallType, _RefSpec, _Signature) \ + typename ::testing::internal::Function::Result \ + GMOCK_INTERNAL_EXPAND(_CallType) \ + _MethodName(GMOCK_PP_REPEAT(GMOCK_INTERNAL_PARAMETER, _Signature, _N)) \ + GMOCK_PP_IF(_Constness, const, ) _RefSpec _NoexceptSpec \ + GMOCK_PP_IF(_Override, override, ) GMOCK_PP_IF(_Final, final, ) { \ + GMOCK_MOCKER_(_N, _Constness, _MethodName) \ + .SetOwnerAndName(this, #_MethodName); \ + return GMOCK_MOCKER_(_N, _Constness, _MethodName) \ + .Invoke(GMOCK_PP_REPEAT(GMOCK_INTERNAL_FORWARD_ARG, _Signature, _N)); \ + } \ + ::testing::MockSpec gmock_##_MethodName( \ + GMOCK_PP_REPEAT(GMOCK_INTERNAL_MATCHER_PARAMETER, _Signature, _N)) \ + GMOCK_PP_IF(_Constness, const, ) _RefSpec { \ + GMOCK_MOCKER_(_N, _Constness, _MethodName).RegisterOwner(this); \ + return GMOCK_MOCKER_(_N, _Constness, _MethodName) \ + .With(GMOCK_PP_REPEAT(GMOCK_INTERNAL_MATCHER_ARGUMENT, , _N)); \ + } \ + ::testing::MockSpec gmock_##_MethodName( \ + const ::testing::internal::WithoutMatchers&, \ + GMOCK_PP_IF(_Constness, const, )::testing::internal::Function< \ + GMOCK_PP_REMOVE_PARENS(_Signature)>*) const _RefSpec _NoexceptSpec { \ + return ::testing::internal::ThisRefAdjuster::Adjust(*this) \ + .gmock_##_MethodName(GMOCK_PP_REPEAT( \ + GMOCK_INTERNAL_A_MATCHER_ARGUMENT, _Signature, _N)); \ + } \ + mutable ::testing::FunctionMocker \ + GMOCK_MOCKER_(_N, _Constness, _MethodName) + +#define GMOCK_INTERNAL_EXPAND(...) __VA_ARGS__ + +// Valid modifiers. +#define GMOCK_INTERNAL_HAS_CONST(_Tuple) \ + GMOCK_PP_HAS_COMMA(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_DETECT_CONST, ~, _Tuple)) + +#define GMOCK_INTERNAL_HAS_OVERRIDE(_Tuple) \ + GMOCK_PP_HAS_COMMA( \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_DETECT_OVERRIDE, ~, _Tuple)) + +#define GMOCK_INTERNAL_HAS_FINAL(_Tuple) \ + GMOCK_PP_HAS_COMMA(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_DETECT_FINAL, ~, _Tuple)) + +#define GMOCK_INTERNAL_GET_NOEXCEPT_SPEC(_Tuple) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_NOEXCEPT_SPEC_IF_NOEXCEPT, ~, _Tuple) + +#define GMOCK_INTERNAL_NOEXCEPT_SPEC_IF_NOEXCEPT(_i, _, _elem) \ + GMOCK_PP_IF( \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_NOEXCEPT(_i, _, _elem)), \ + _elem, ) + +#define GMOCK_INTERNAL_GET_CALLTYPE_SPEC(_Tuple) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_CALLTYPE_SPEC_IF_CALLTYPE, ~, _Tuple) + +#define GMOCK_INTERNAL_CALLTYPE_SPEC_IF_CALLTYPE(_i, _, _elem) \ + GMOCK_PP_IF( \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_CALLTYPE(_i, _, _elem)), \ + GMOCK_PP_CAT(GMOCK_INTERNAL_UNPACK_, _elem), ) + +#define GMOCK_INTERNAL_GET_REF_SPEC(_Tuple) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_REF_SPEC_IF_REF, ~, _Tuple) + +#define GMOCK_INTERNAL_REF_SPEC_IF_REF(_i, _, _elem) \ + GMOCK_PP_IF(GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_REF(_i, _, _elem)), \ + GMOCK_PP_CAT(GMOCK_INTERNAL_UNPACK_, _elem), ) + +#ifdef GMOCK_INTERNAL_STRICT_SPEC_ASSERT +#define GMOCK_INTERNAL_ASSERT_VALID_SPEC_ELEMENT(_i, _, _elem) \ + static_assert( \ + ::testing::internal::ValidateSpec(GMOCK_PP_STRINGIZE(_elem)), \ + "Token \'" GMOCK_PP_STRINGIZE( \ + _elem) "\' cannot be recognized as a valid specification " \ + "modifier. 
Is a ',' missing?"); +#else +#define GMOCK_INTERNAL_ASSERT_VALID_SPEC_ELEMENT(_i, _, _elem) \ + static_assert( \ + (GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_CONST(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_OVERRIDE(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_FINAL(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_NOEXCEPT(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_REF(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_CALLTYPE(_i, _, _elem))) == 1, \ + GMOCK_PP_STRINGIZE( \ + _elem) " cannot be recognized as a valid specification modifier."); +#endif // GMOCK_INTERNAL_STRICT_SPEC_ASSERT + +// Modifiers implementation. +#define GMOCK_INTERNAL_DETECT_CONST(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_CONST_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_CONST_I_const , + +#define GMOCK_INTERNAL_DETECT_OVERRIDE(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_OVERRIDE_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_OVERRIDE_I_override , + +#define GMOCK_INTERNAL_DETECT_FINAL(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_FINAL_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_FINAL_I_final , + +#define GMOCK_INTERNAL_DETECT_NOEXCEPT(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_NOEXCEPT_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_NOEXCEPT_I_noexcept , + +#define GMOCK_INTERNAL_DETECT_REF(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_REF_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_REF_I_ref , + +#define GMOCK_INTERNAL_UNPACK_ref(x) x + +#define GMOCK_INTERNAL_DETECT_CALLTYPE(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_CALLTYPE_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_CALLTYPE_I_Calltype , + +#define GMOCK_INTERNAL_UNPACK_Calltype(...) __VA_ARGS__ + +// Note: The use of `identity_t` here allows _Ret to represent return types that +// would normally need to be specified in a different way. For example, a method +// returning a function pointer must be written as +// +// fn_ptr_return_t (*method(method_args_t...))(fn_ptr_args_t...) +// +// But we only support placing the return type at the beginning. To handle this, +// we wrap all calls in identity_t, so that a declaration will be expanded to +// +// identity_t method(method_args_t...) +// +// This allows us to work around the syntactic oddities of function/method +// types. +#define GMOCK_INTERNAL_SIGNATURE(_Ret, _Args) \ + ::testing::internal::identity_t( \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_GET_TYPE, _, _Args)) + +#define GMOCK_INTERNAL_GET_TYPE(_i, _, _elem) \ + GMOCK_PP_COMMA_IF(_i) \ + GMOCK_PP_IF(GMOCK_PP_IS_BEGIN_PARENS(_elem), GMOCK_PP_REMOVE_PARENS, \ + GMOCK_PP_IDENTITY) \ + (_elem) + +#define GMOCK_INTERNAL_PARAMETER(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + GMOCK_INTERNAL_ARG_O(_i, GMOCK_PP_REMOVE_PARENS(_Signature)) \ + gmock_a##_i + +#define GMOCK_INTERNAL_FORWARD_ARG(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + ::std::forward(gmock_a##_i) + +#define GMOCK_INTERNAL_MATCHER_PARAMETER(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + GMOCK_INTERNAL_MATCHER_O(_i, GMOCK_PP_REMOVE_PARENS(_Signature)) \ + gmock_a##_i + +#define GMOCK_INTERNAL_MATCHER_ARGUMENT(_i, _1, _2) \ + GMOCK_PP_COMMA_IF(_i) \ + gmock_a##_i + +#define GMOCK_INTERNAL_A_MATCHER_ARGUMENT(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + ::testing::A() + +#define GMOCK_INTERNAL_ARG_O(_i, ...) \ + typename ::testing::internal::Function<__VA_ARGS__>::template Arg<_i>::type + +#define GMOCK_INTERNAL_MATCHER_O(_i, ...) 
\ + const ::testing::Matcher::template Arg<_i>::type>& + +#define MOCK_METHOD0(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 0, __VA_ARGS__) +#define MOCK_METHOD1(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 1, __VA_ARGS__) +#define MOCK_METHOD2(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 2, __VA_ARGS__) +#define MOCK_METHOD3(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 3, __VA_ARGS__) +#define MOCK_METHOD4(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 4, __VA_ARGS__) +#define MOCK_METHOD5(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 5, __VA_ARGS__) +#define MOCK_METHOD6(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 6, __VA_ARGS__) +#define MOCK_METHOD7(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 7, __VA_ARGS__) +#define MOCK_METHOD8(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 8, __VA_ARGS__) +#define MOCK_METHOD9(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 9, __VA_ARGS__) +#define MOCK_METHOD10(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, , m, 10, __VA_ARGS__) + +#define MOCK_CONST_METHOD0(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 0, __VA_ARGS__) +#define MOCK_CONST_METHOD1(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 1, __VA_ARGS__) +#define MOCK_CONST_METHOD2(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 2, __VA_ARGS__) +#define MOCK_CONST_METHOD3(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 3, __VA_ARGS__) +#define MOCK_CONST_METHOD4(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 4, __VA_ARGS__) +#define MOCK_CONST_METHOD5(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 5, __VA_ARGS__) +#define MOCK_CONST_METHOD6(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 6, __VA_ARGS__) +#define MOCK_CONST_METHOD7(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 7, __VA_ARGS__) +#define MOCK_CONST_METHOD8(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 8, __VA_ARGS__) +#define MOCK_CONST_METHOD9(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 9, __VA_ARGS__) +#define MOCK_CONST_METHOD10(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 10, __VA_ARGS__) + +#define MOCK_METHOD0_T(m, ...) MOCK_METHOD0(m, __VA_ARGS__) +#define MOCK_METHOD1_T(m, ...) MOCK_METHOD1(m, __VA_ARGS__) +#define MOCK_METHOD2_T(m, ...) MOCK_METHOD2(m, __VA_ARGS__) +#define MOCK_METHOD3_T(m, ...) MOCK_METHOD3(m, __VA_ARGS__) +#define MOCK_METHOD4_T(m, ...) MOCK_METHOD4(m, __VA_ARGS__) +#define MOCK_METHOD5_T(m, ...) MOCK_METHOD5(m, __VA_ARGS__) +#define MOCK_METHOD6_T(m, ...) MOCK_METHOD6(m, __VA_ARGS__) +#define MOCK_METHOD7_T(m, ...) MOCK_METHOD7(m, __VA_ARGS__) +#define MOCK_METHOD8_T(m, ...) MOCK_METHOD8(m, __VA_ARGS__) +#define MOCK_METHOD9_T(m, ...) MOCK_METHOD9(m, __VA_ARGS__) +#define MOCK_METHOD10_T(m, ...) MOCK_METHOD10(m, __VA_ARGS__) + +#define MOCK_CONST_METHOD0_T(m, ...) MOCK_CONST_METHOD0(m, __VA_ARGS__) +#define MOCK_CONST_METHOD1_T(m, ...) MOCK_CONST_METHOD1(m, __VA_ARGS__) +#define MOCK_CONST_METHOD2_T(m, ...) MOCK_CONST_METHOD2(m, __VA_ARGS__) +#define MOCK_CONST_METHOD3_T(m, ...) MOCK_CONST_METHOD3(m, __VA_ARGS__) +#define MOCK_CONST_METHOD4_T(m, ...) MOCK_CONST_METHOD4(m, __VA_ARGS__) +#define MOCK_CONST_METHOD5_T(m, ...) MOCK_CONST_METHOD5(m, __VA_ARGS__) +#define MOCK_CONST_METHOD6_T(m, ...) MOCK_CONST_METHOD6(m, __VA_ARGS__) +#define MOCK_CONST_METHOD7_T(m, ...) MOCK_CONST_METHOD7(m, __VA_ARGS__) +#define MOCK_CONST_METHOD8_T(m, ...) MOCK_CONST_METHOD8(m, __VA_ARGS__) +#define MOCK_CONST_METHOD9_T(m, ...) MOCK_CONST_METHOD9(m, __VA_ARGS__) +#define MOCK_CONST_METHOD10_T(m, ...) MOCK_CONST_METHOD10(m, __VA_ARGS__) + +#define MOCK_METHOD0_WITH_CALLTYPE(ct, m, ...) 
\ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 0, __VA_ARGS__) +#define MOCK_METHOD1_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 1, __VA_ARGS__) +#define MOCK_METHOD2_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 2, __VA_ARGS__) +#define MOCK_METHOD3_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 3, __VA_ARGS__) +#define MOCK_METHOD4_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 4, __VA_ARGS__) +#define MOCK_METHOD5_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 5, __VA_ARGS__) +#define MOCK_METHOD6_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 6, __VA_ARGS__) +#define MOCK_METHOD7_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 7, __VA_ARGS__) +#define MOCK_METHOD8_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 8, __VA_ARGS__) +#define MOCK_METHOD9_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 9, __VA_ARGS__) +#define MOCK_METHOD10_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 10, __VA_ARGS__) + +#define MOCK_CONST_METHOD0_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 0, __VA_ARGS__) +#define MOCK_CONST_METHOD1_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 1, __VA_ARGS__) +#define MOCK_CONST_METHOD2_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 2, __VA_ARGS__) +#define MOCK_CONST_METHOD3_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 3, __VA_ARGS__) +#define MOCK_CONST_METHOD4_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 4, __VA_ARGS__) +#define MOCK_CONST_METHOD5_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 5, __VA_ARGS__) +#define MOCK_CONST_METHOD6_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 6, __VA_ARGS__) +#define MOCK_CONST_METHOD7_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 7, __VA_ARGS__) +#define MOCK_CONST_METHOD8_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 8, __VA_ARGS__) +#define MOCK_CONST_METHOD9_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 9, __VA_ARGS__) +#define MOCK_CONST_METHOD10_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 10, __VA_ARGS__) + +#define MOCK_METHOD0_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD0_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD1_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD1_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD2_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD2_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD3_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD3_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD4_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD4_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD5_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD5_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD6_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD6_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD7_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD7_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD8_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD8_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD9_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD9_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD10_T_WITH_CALLTYPE(ct, m, ...) 
\ + MOCK_METHOD10_WITH_CALLTYPE(ct, m, __VA_ARGS__) + +#define MOCK_CONST_METHOD0_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD0_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD1_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD1_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD2_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD2_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD3_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD3_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD4_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD4_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD5_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD5_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD6_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD6_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD7_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD7_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD8_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD8_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD9_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD9_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD10_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD10_WITH_CALLTYPE(ct, m, __VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHODN(constness, ct, Method, args_num, ...) \ + GMOCK_INTERNAL_ASSERT_VALID_SIGNATURE( \ + args_num, ::testing::internal::identity_t<__VA_ARGS__>); \ + GMOCK_INTERNAL_MOCK_METHOD_IMPL( \ + args_num, Method, GMOCK_PP_NARG0(constness), 0, 0, , ct, , \ + (::testing::internal::identity_t<__VA_ARGS__>)) + +#define GMOCK_MOCKER_(arity, constness, Method) \ + GTEST_CONCAT_TOKEN_(gmock##constness##arity##_##Method##_, __LINE__) + +#endif // GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_FUNCTION_MOCKER_H_ diff --git a/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-matchers.h b/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-matchers.h new file mode 100644 index 0000000000000000000000000000000000000000..9e634f7f1c47b267c2e2ddf5392445cb393ab669 --- /dev/null +++ b/3rdparty/googletest-1.13.0/googlemock/include/gmock/gmock-matchers.h @@ -0,0 +1,5620 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Mock - a framework for writing C++ mock classes. +// +// The MATCHER* family of macros can be used in a namespace scope to +// define custom matchers easily. +// +// Basic Usage +// =========== +// +// The syntax +// +// MATCHER(name, description_string) { statements; } +// +// defines a matcher with the given name that executes the statements, +// which must return a bool to indicate if the match succeeds. Inside +// the statements, you can refer to the value being matched by 'arg', +// and refer to its type by 'arg_type'. +// +// The description string documents what the matcher does, and is used +// to generate the failure message when the match fails. Since a +// MATCHER() is usually defined in a header file shared by multiple +// C++ source files, we require the description to be a C-string +// literal to avoid possible side effects. It can be empty, in which +// case we'll use the sequence of words in the matcher name as the +// description. +// +// For example: +// +// MATCHER(IsEven, "") { return (arg % 2) == 0; } +// +// allows you to write +// +// // Expects mock_foo.Bar(n) to be called where n is even. +// EXPECT_CALL(mock_foo, Bar(IsEven())); +// +// or, +// +// // Verifies that the value of some_expression is even. +// EXPECT_THAT(some_expression, IsEven()); +// +// If the above assertion fails, it will print something like: +// +// Value of: some_expression +// Expected: is even +// Actual: 7 +// +// where the description "is even" is automatically calculated from the +// matcher name IsEven. +// +// Argument Type +// ============= +// +// Note that the type of the value being matched (arg_type) is +// determined by the context in which you use the matcher and is +// supplied to you by the compiler, so you don't need to worry about +// declaring it (nor can you). This allows the matcher to be +// polymorphic. For example, IsEven() can be used to match any type +// where the value of "(arg % 2) == 0" can be implicitly converted to +// a bool. In the "Bar(IsEven())" example above, if method Bar() +// takes an int, 'arg_type' will be int; if it takes an unsigned long, +// 'arg_type' will be unsigned long; and so on. +// +// Parameterizing Matchers +// ======================= +// +// Sometimes you'll want to parameterize the matcher. For that you +// can use another macro: +// +// MATCHER_P(name, param_name, description_string) { statements; } +// +// For example: +// +// MATCHER_P(HasAbsoluteValue, value, "") { return abs(arg) == value; } +// +// will allow you to write: +// +// EXPECT_THAT(Blah("a"), HasAbsoluteValue(n)); +// +// which may lead to this message (assuming n is 10): +// +// Value of: Blah("a") +// Expected: has absolute value 10 +// Actual: -9 +// +// Note that both the matcher description and its parameter are +// printed, making the message human-friendly. +// +// In the matcher definition body, you can write 'foo_type' to +// reference the type of a parameter named 'foo'. 
For example, in the +// body of MATCHER_P(HasAbsoluteValue, value) above, you can write +// 'value_type' to refer to the type of 'value'. +// +// We also provide MATCHER_P2, MATCHER_P3, ..., up to MATCHER_P$n to +// support multi-parameter matchers. +// +// Describing Parameterized Matchers +// ================================= +// +// The last argument to MATCHER*() is a string-typed expression. The +// expression can reference all of the matcher's parameters and a +// special bool-typed variable named 'negation'. When 'negation' is +// false, the expression should evaluate to the matcher's description; +// otherwise it should evaluate to the description of the negation of +// the matcher. For example, +// +// using testing::PrintToString; +// +// MATCHER_P2(InClosedRange, low, hi, +// std::string(negation ? "is not" : "is") + " in range [" + +// PrintToString(low) + ", " + PrintToString(hi) + "]") { +// return low <= arg && arg <= hi; +// } +// ... +// EXPECT_THAT(3, InClosedRange(4, 6)); +// EXPECT_THAT(3, Not(InClosedRange(2, 4))); +// +// would generate two failures that contain the text: +// +// Expected: is in range [4, 6] +// ... +// Expected: is not in range [2, 4] +// +// If you specify "" as the description, the failure message will +// contain the sequence of words in the matcher name followed by the +// parameter values printed as a tuple. For example, +// +// MATCHER_P2(InClosedRange, low, hi, "") { ... } +// ... +// EXPECT_THAT(3, InClosedRange(4, 6)); +// EXPECT_THAT(3, Not(InClosedRange(2, 4))); +// +// would generate two failures that contain the text: +// +// Expected: in closed range (4, 6) +// ... +// Expected: not (in closed range (2, 4)) +// +// Types of Matcher Parameters +// =========================== +// +// For the purpose of typing, you can view +// +// MATCHER_Pk(Foo, p1, ..., pk, description_string) { ... } +// +// as shorthand for +// +// template +// FooMatcherPk +// Foo(p1_type p1, ..., pk_type pk) { ... } +// +// When you write Foo(v1, ..., vk), the compiler infers the types of +// the parameters v1, ..., and vk for you. If you are not happy with +// the result of the type inference, you can specify the types by +// explicitly instantiating the template, as in Foo(5, +// false). As said earlier, you don't get to (or need to) specify +// 'arg_type' as that's determined by the context in which the matcher +// is used. You can assign the result of expression Foo(p1, ..., pk) +// to a variable of type FooMatcherPk. This +// can be useful when composing matchers. +// +// While you can instantiate a matcher template with reference types, +// passing the parameters by pointer usually makes your code more +// readable. If, however, you still want to pass a parameter by +// reference, be aware that in the failure message generated by the +// matcher you will see the value of the referenced object but not its +// address. +// +// Explaining Match Results +// ======================== +// +// Sometimes the matcher description alone isn't enough to explain why +// the match has failed or succeeded. For example, when expecting a +// long string, it can be very helpful to also print the diff between +// the expected string and the actual one. 
To achieve that, you can +// optionally stream additional information to a special variable +// named result_listener, whose type is a pointer to class +// MatchResultListener: +// +// MATCHER_P(EqualsLongString, str, "") { +// if (arg == str) return true; +// +// *result_listener << "the difference: " +/// << DiffStrings(str, arg); +// return false; +// } +// +// Overloading Matchers +// ==================== +// +// You can overload matchers with different numbers of parameters: +// +// MATCHER_P(Blah, a, description_string1) { ... } +// MATCHER_P2(Blah, a, b, description_string2) { ... } +// +// Caveats +// ======= +// +// When defining a new matcher, you should also consider implementing +// MatcherInterface or using MakePolymorphicMatcher(). These +// approaches require more work than the MATCHER* macros, but also +// give you more control on the types of the value being matched and +// the matcher parameters, which may leads to better compiler error +// messages when the matcher is used wrong. They also allow +// overloading matchers based on parameter types (as opposed to just +// based on the number of parameters). +// +// MATCHER*() can only be used in a namespace scope as templates cannot be +// declared inside of a local class. +// +// More Information +// ================ +// +// To learn more about using these macros, please search for 'MATCHER' +// on +// https://github.com/google/googletest/blob/main/docs/gmock_cook_book.md +// +// This file also implements some commonly used argument matchers. More +// matchers can be defined by the user implementing the +// MatcherInterface interface if necessary. +// +// See googletest/include/gtest/gtest-matchers.h for the definition of class +// Matcher, class MatcherInterface, and others. + +// IWYU pragma: private, include "gmock/gmock.h" +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_MATCHERS_H_ +#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_MATCHERS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include + +#include "gmock/internal/gmock-internal-utils.h" +#include "gmock/internal/gmock-port.h" +#include "gmock/internal/gmock-pp.h" +#include "gtest/gtest.h" + +// MSVC warning C5046 is new as of VS2017 version 15.8. +#if defined(_MSC_VER) && _MSC_VER >= 1915 +#define GMOCK_MAYBE_5046_ 5046 +#else +#define GMOCK_MAYBE_5046_ +#endif + +GTEST_DISABLE_MSC_WARNINGS_PUSH_( + 4251 GMOCK_MAYBE_5046_ /* class A needs to have dll-interface to be used by + clients of class B */ + /* Symbol involving type with internal linkage not defined */) + +namespace testing { + +// To implement a matcher Foo for type T, define: +// 1. a class FooMatcherImpl that implements the +// MatcherInterface interface, and +// 2. a factory function that creates a Matcher object from a +// FooMatcherImpl*. +// +// The two-level delegation design makes it possible to allow a user +// to write "v" instead of "Eq(v)" where a Matcher is expected, which +// is impossible if we pass matchers by pointers. It also eases +// ownership management as Matcher objects can now be copied like +// plain values. + +// A match result listener that stores the explanation in a string. +class StringMatchResultListener : public MatchResultListener { + public: + StringMatchResultListener() : MatchResultListener(&ss_) {} + + // Returns the explanation accumulated so far. + std::string str() const { return ss_.str(); } + + // Clears the explanation accumulated so far. 
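+  //
+  // Illustrative sketch (IsEven is the example MATCHER described in the
+  // comments above; the value 7 is arbitrary): this listener is typically
+  // used when exercising a custom matcher directly:
+  //
+  //   StringMatchResultListener listener;
+  //   ExplainMatchResult(IsEven(), 7, &listener);
+  //   // listener.str() now holds whatever the matcher streamed, if anything.
+  //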
+ void Clear() { ss_.str(""); } + + private: + ::std::stringstream ss_; + + StringMatchResultListener(const StringMatchResultListener&) = delete; + StringMatchResultListener& operator=(const StringMatchResultListener&) = + delete; +}; + +// Anything inside the 'internal' namespace IS INTERNAL IMPLEMENTATION +// and MUST NOT BE USED IN USER CODE!!! +namespace internal { + +// The MatcherCastImpl class template is a helper for implementing +// MatcherCast(). We need this helper in order to partially +// specialize the implementation of MatcherCast() (C++ allows +// class/struct templates to be partially specialized, but not +// function templates.). + +// This general version is used when MatcherCast()'s argument is a +// polymorphic matcher (i.e. something that can be converted to a +// Matcher but is not one yet; for example, Eq(value)) or a value (for +// example, "hello"). +template +class MatcherCastImpl { + public: + static Matcher Cast(const M& polymorphic_matcher_or_value) { + // M can be a polymorphic matcher, in which case we want to use + // its conversion operator to create Matcher. Or it can be a value + // that should be passed to the Matcher's constructor. + // + // We can't call Matcher(polymorphic_matcher_or_value) when M is a + // polymorphic matcher because it'll be ambiguous if T has an implicit + // constructor from M (this usually happens when T has an implicit + // constructor from any type). + // + // It won't work to unconditionally implicit_cast + // polymorphic_matcher_or_value to Matcher because it won't trigger + // a user-defined conversion from M to T if one exists (assuming M is + // a value). + return CastImpl(polymorphic_matcher_or_value, + std::is_convertible>{}, + std::is_convertible{}); + } + + private: + template + static Matcher CastImpl(const M& polymorphic_matcher_or_value, + std::true_type /* convertible_to_matcher */, + std::integral_constant) { + // M is implicitly convertible to Matcher, which means that either + // M is a polymorphic matcher or Matcher has an implicit constructor + // from M. In both cases using the implicit conversion will produce a + // matcher. + // + // Even if T has an implicit constructor from M, it won't be called because + // creating Matcher would require a chain of two user-defined conversions + // (first to create T from M and then to create Matcher from T). + return polymorphic_matcher_or_value; + } + + // M can't be implicitly converted to Matcher, so M isn't a polymorphic + // matcher. It's a value of a type implicitly convertible to T. Use direct + // initialization to create a matcher. + static Matcher CastImpl(const M& value, + std::false_type /* convertible_to_matcher */, + std::true_type /* convertible_to_T */) { + return Matcher(ImplicitCast_(value)); + } + + // M can't be implicitly converted to either Matcher or T. Attempt to use + // polymorphic matcher Eq(value) in this case. + // + // Note that we first attempt to perform an implicit cast on the value and + // only fall back to the polymorphic Eq() matcher afterwards because the + // latter calls bool operator==(const Lhs& lhs, const Rhs& rhs) in the end + // which might be undefined even when Rhs is implicitly convertible to Lhs + // (e.g. std::pair vs. std::pair). + // + // We don't define this method inline as we need the declaration of Eq(). 
+ static Matcher CastImpl(const M& value, + std::false_type /* convertible_to_matcher */, + std::false_type /* convertible_to_T */); +}; + +// This more specialized version is used when MatcherCast()'s argument +// is already a Matcher. This only compiles when type T can be +// statically converted to type U. +template +class MatcherCastImpl> { + public: + static Matcher Cast(const Matcher& source_matcher) { + return Matcher(new Impl(source_matcher)); + } + + private: + class Impl : public MatcherInterface { + public: + explicit Impl(const Matcher& source_matcher) + : source_matcher_(source_matcher) {} + + // We delegate the matching logic to the source matcher. + bool MatchAndExplain(T x, MatchResultListener* listener) const override { + using FromType = typename std::remove_cv::type>::type>::type; + using ToType = typename std::remove_cv::type>::type>::type; + // Do not allow implicitly converting base*/& to derived*/&. + static_assert( + // Do not trigger if only one of them is a pointer. That implies a + // regular conversion and not a down_cast. + (std::is_pointer::type>::value != + std::is_pointer::type>::value) || + std::is_same::value || + !std::is_base_of::value, + "Can't implicitly convert from to "); + + // Do the cast to `U` explicitly if necessary. + // Otherwise, let implicit conversions do the trick. + using CastType = + typename std::conditional::value, + T&, U>::type; + + return source_matcher_.MatchAndExplain(static_cast(x), + listener); + } + + void DescribeTo(::std::ostream* os) const override { + source_matcher_.DescribeTo(os); + } + + void DescribeNegationTo(::std::ostream* os) const override { + source_matcher_.DescribeNegationTo(os); + } + + private: + const Matcher source_matcher_; + }; +}; + +// This even more specialized version is used for efficiently casting +// a matcher to its own type. +template +class MatcherCastImpl> { + public: + static Matcher Cast(const Matcher& matcher) { return matcher; } +}; + +// Template specialization for parameterless Matcher. +template +class MatcherBaseImpl { + public: + MatcherBaseImpl() = default; + + template + operator ::testing::Matcher() const { // NOLINT(runtime/explicit) + return ::testing::Matcher(new + typename Derived::template gmock_Impl()); + } +}; + +// Template specialization for Matcher with parameters. +template