#!/usr/bin/env python3 # Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of NVIDIA CORPORATION nor the names of its # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import base64 import gzip import zlib from urllib.parse import quote import gevent import gevent.pool import rapidjson as json from geventhttpclient import HTTPClient from geventhttpclient.url import URL from tritonclient.utils import raise_error from .._client import InferenceServerClientBase from .._request import Request from ._infer_result import InferResult from ._utils import _get_inference_request, _get_query_string, _raise_if_error class InferAsyncRequest: """An object of InferAsyncRequest class is used to describe a handle to an ongoing asynchronous inference request. Parameters ---------- greenlet : gevent.Greenlet The greenlet object which will provide the results. For further details about greenlets refer http://www.gevent.org/api/gevent.greenlet.html. verbose : bool If True generate verbose output. Default value is False. """ def __init__(self, greenlet, verbose=False): self._greenlet = greenlet self._verbose = verbose def get_result(self, block=True, timeout=None): """Get the results of the associated asynchronous inference. Parameters ---------- block : bool If block is True, the function will wait till the corresponding response is received from the server. Default value is True. timeout : int The maximum wait time for the function. This setting is ignored if the block is set False. Default is None, which means the function will block indefinitely till the corresponding response is received. Returns ------- InferResult The object holding the result of the async inference. Raises ------ InferenceServerException If server fails to perform inference or failed to respond within specified timeout. """ try: response = self._greenlet.get(block=block, timeout=timeout) except gevent.Timeout as e: raise_error("failed to obtain inference response") _raise_if_error(response) return InferResult(response, self._verbose) class InferenceServerClient(InferenceServerClientBase): """An InferenceServerClient object is used to perform any kind of communication with the InferenceServer using http protocol. None of the methods are thread safe. The object is intended to be used by a single thread and simultaneously calling different methods with different threads is not supported and will cause undefined behavior. Parameters ---------- url : str The inference server name, port and optional base path in the following format: host:port/, e.g. 'localhost:8000'. verbose : bool If True generate verbose output. Default value is False. concurrency : int The number of connections to create for this client. Default value is 1. connection_timeout : float The timeout value for the connection. Default value is 60.0 sec. network_timeout : float The timeout value for the network. Default value is 60.0 sec max_greenlets : int Determines the maximum allowed number of worker greenlets for handling asynchronous inference requests. Default value is None, which means there will be no restriction on the number of greenlets created. ssl : bool If True, channels the requests to encrypted https scheme. Some improper settings may cause connection to prematurely terminate with an unsuccessful handshake. See `ssl_context_factory` option for using secure default settings. Default value for this option is False. ssl_options : dict Any options supported by `ssl.wrap_socket` specified as dictionary. The argument is ignored if 'ssl' is specified False. ssl_context_factory : SSLContext callable It must be a callbable that returns a SSLContext. Set to `gevent.ssl.create_default_context` to use contexts with secure default settings. This should most likely resolve connection issues in a secure way. The default value for this option is None which directly wraps the socket with the options provided via `ssl_options`. The argument is ignored if 'ssl' is specified False. insecure : bool If True, then does not match the host name with the certificate. Default value is False. The argument is ignored if 'ssl' is specified False. Raises ------ Exception If unable to create a client. """ def __init__( self, url, verbose=False, concurrency=1, connection_timeout=60.0, network_timeout=60.0, max_greenlets=None, ssl=False, ssl_options=None, ssl_context_factory=None, insecure=False, ): super().__init__() if url.startswith("http://") or url.startswith("https://"): raise_error("url should not include the scheme") scheme = "https://" if ssl else "http://" self._parsed_url = URL(scheme + url) self._base_uri = self._parsed_url.request_uri.rstrip("/") self._client_stub = HTTPClient.from_url( self._parsed_url, concurrency=concurrency, connection_timeout=connection_timeout, network_timeout=network_timeout, ssl_options=ssl_options, ssl_context_factory=ssl_context_factory, insecure=insecure, ) self._pool = gevent.pool.Pool(max_greenlets) self._verbose = verbose def __enter__(self): return self def __exit__(self, type, value, traceback): self.close() def __del__(self): self.close() def close(self): """Close the client. Any future calls to server will result in an Error. """ self._pool.join() self._client_stub.close() def _get(self, request_uri, headers, query_params): """Issues the GET request to the server Parameters ---------- request_uri: str The request URI to be used in GET request. headers: dict Additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction. Returns ------- geventhttpclient.response.HTTPSocketPoolResponse The response from server. """ request = Request(headers) self._call_plugin(request) # Update the headers based on plugin invocation headers = request.headers self._validate_headers(headers) if self._base_uri is not None: request_uri = self._base_uri + "/" + request_uri if query_params is not None: request_uri = request_uri + "?" + _get_query_string(query_params) if self._verbose: print("GET {}, headers {}".format(request_uri, headers)) if headers is not None: response = self._client_stub.get(request_uri, headers=headers) else: response = self._client_stub.get(request_uri) if self._verbose: print(response) return response def _post(self, request_uri, request_body, headers, query_params): """Issues the POST request to the server Parameters ---------- request_uri: str The request URI to be used in POST request. request_body: str The body of the request headers: dict Additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction. Returns ------- geventhttpclient.response.HTTPSocketPoolResponse The response from server. """ request = Request(headers) self._call_plugin(request) # Update the headers based on plugin invocation headers = request.headers self._validate_headers(headers) if self._base_uri is not None: request_uri = self._base_uri + "/" + request_uri if query_params is not None: request_uri = request_uri + "?" + _get_query_string(query_params) if self._verbose: print("POST {}, headers {}\n{}".format(request_uri, headers, request_body)) if headers is not None: response = self._client_stub.post( request_uri=request_uri, body=request_body, headers=headers ) else: response = self._client_stub.post( request_uri=request_uri, body=request_body ) if self._verbose: print(response) return response def _validate_headers(self, headers): """Checks for any unsupported HTTP headers before processing a request. Parameters ---------- headers: dict HTTP headers to validate before processing the request. Raises ------ InferenceServerException If an unsupported HTTP header is included in a request. """ if not headers: return # HTTP headers are case-insensitive, so force lowercase for comparison headers_lowercase = {k.lower(): v for k, v in headers.items()} # The python client lirary (and geventhttpclient) do not encode request # data based on "Transfer-Encoding" header, so reject this header if # included. Other libraries may do this encoding under the hood. # The python client library does expose special arguments to support # some "Content-Encoding" headers. if "transfer-encoding" in headers_lowercase: raise_error( "Unsupported HTTP header: 'Transfer-Encoding' is not " "supported in the Python client library. Use raw HTTP " "request libraries or the C++ client instead for this " "header." ) def is_server_live(self, headers=None, query_params=None): """Contact the inference server and get liveness. Parameters ---------- headers: dict Optional dictionary specifying additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction. Returns ------- bool True if server is live, False if server is not live. Raises ------ Exception If unable to get liveness. """ request_uri = "v2/health/live" response = self._get( request_uri=request_uri, headers=headers, query_params=query_params ) return response.status_code == 200 def is_server_ready(self, headers=None, query_params=None): """Contact the inference server and get readiness. Parameters ---------- headers: dict Optional dictionary specifying additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction. Returns ------- bool True if server is ready, False if server is not ready. Raises ------ Exception If unable to get readiness. """ request_uri = "v2/health/ready" response = self._get( request_uri=request_uri, headers=headers, query_params=query_params ) return response.status_code == 200 def is_model_ready( self, model_name, model_version="", headers=None, query_params=None ): """Contact the inference server and get the readiness of specified model. Parameters ---------- model_name: str The name of the model to check for readiness. model_version: str The version of the model to check for readiness. The default value is an empty string which means then the server will choose a version based on the model and internal policy. headers: dict Optional dictionary specifying additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction. Returns ------- bool True if the model is ready, False if not ready. Raises ------ Exception If unable to get model readiness. """ if type(model_version) != str: raise_error("model version must be a string") if model_version != "": request_uri = "v2/models/{}/versions/{}/ready".format( quote(model_name), model_version ) else: request_uri = "v2/models/{}/ready".format(quote(model_name)) response = self._get( request_uri=request_uri, headers=headers, query_params=query_params ) return response.status_code == 200 def get_server_metadata(self, headers=None, query_params=None): """Contact the inference server and get its metadata. Parameters ---------- headers: dict Optional dictionary specifying additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction. Returns ------- dict The JSON dict holding the metadata. Raises ------ InferenceServerException If unable to get server metadata. """ request_uri = "v2" response = self._get( request_uri=request_uri, headers=headers, query_params=query_params ) _raise_if_error(response) content = response.read() if self._verbose: print(content) return json.loads(content) def get_model_metadata( self, model_name, model_version="", headers=None, query_params=None ): """Contact the inference server and get the metadata for specified model. Parameters ---------- model_name: str The name of the model model_version: str The version of the model to get metadata. The default value is an empty string which means then the server will choose a version based on the model and internal policy. headers: dict Optional dictionary specifying additional HTTP headers to include in the request query_params: dict Optional url query parameters to use in network transaction Returns ------- dict The JSON dict holding the metadata. Raises ------ InferenceServerException If unable to get model metadata. """ if type(model_version) != str: raise_error("model version must be a string") if model_version != "": request_uri = "v2/models/{}/versions/{}".format( quote(model_name), model_version ) else: request_uri = "v2/models/{}".format(quote(model_name)) response = self._get( request_uri=request_uri, headers=headers, query_params=query_params ) _raise_if_error(response) content = response.read() if self._verbose: print(content) return json.loads(content) def get_model_config( self, model_name, model_version="", headers=None, query_params=None ): """Contact the inference server and get the configuration for specified model. Parameters ---------- model_name: str The name of the model model_version: str The version of the model to get configuration. The default value is an empty string which means then the server will choose a version based on the model and internal policy. headers: dict Optional dictionary specifying additional HTTP headers to include in the request query_params: dict Optional url query parameters to use in network transaction Returns ------- dict The JSON dict holding the model config. Raises ------ InferenceServerException If unable to get model configuration. """ if model_version != "": request_uri = "v2/models/{}/versions/{}/config".format( quote(model_name), model_version ) else: request_uri = "v2/models/{}/config".format(quote(model_name)) response = self._get( request_uri=request_uri, headers=headers, query_params=query_params ) _raise_if_error(response) content = response.read() if self._verbose: print(content) return json.loads(content) def get_model_repository_index(self, headers=None, query_params=None): """Get the index of model repository contents Parameters ---------- headers: dict Optional dictionary specifying additional HTTP headers to include in the request query_params: dict Optional url query parameters to use in network transaction Returns ------- dict The JSON dict holding the model repository index. Raises ------ InferenceServerException If unable to get the repository index. """ request_uri = "v2/repository/index" response = self._post( request_uri=request_uri, request_body="", headers=headers, query_params=query_params, ) _raise_if_error(response) content = response.read() if self._verbose: print(content) return json.loads(content) def load_model( self, model_name, headers=None, query_params=None, config=None, files=None ): """Request the inference server to load or reload specified model. Parameters ---------- model_name : str The name of the model to be loaded. headers: dict Optional dictionary specifying additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction. config: str Optional JSON representation of a model config provided for the load request, if provided, this config will be used for loading the model. files: dict Optional dictionary specifying file path (with "file:" prefix) in the override model directory to the file content as bytes. The files will form the model directory that the model will be loaded from. If specified, 'config' must be provided to be the model configuration of the override model directory. Raises ------ InferenceServerException If unable to load the model. """ request_uri = "v2/repository/models/{}/load".format(quote(model_name)) load_request = {} if config is not None: if "parameters" not in load_request: load_request["parameters"] = {} load_request["parameters"]["config"] = config if files is not None: for path, content in files.items(): if "parameters" not in load_request: load_request["parameters"] = {} load_request["parameters"][path] = base64.b64encode(content) response = self._post( request_uri=request_uri, request_body=json.dumps(load_request), headers=headers, query_params=query_params, ) _raise_if_error(response) if self._verbose: print("Loaded model '{}'".format(model_name)) def unload_model( self, model_name, headers=None, query_params=None, unload_dependents=False ): """Request the inference server to unload specified model. Parameters ---------- model_name : str The name of the model to be unloaded. headers: dict Optional dictionary specifying additional HTTP headers to include in the request query_params: dict Optional url query parameters to use in network transaction unload_dependents : bool Whether the dependents of the model should also be unloaded. Raises ------ InferenceServerException If unable to unload the model. """ request_uri = "v2/repository/models/{}/unload".format(quote(model_name)) unload_request = {"parameters": {"unload_dependents": unload_dependents}} response = self._post( request_uri=request_uri, request_body=json.dumps(unload_request), headers=headers, query_params=query_params, ) _raise_if_error(response) if self._verbose: print("Loaded model '{}'".format(model_name)) def get_inference_statistics( self, model_name="", model_version="", headers=None, query_params=None ): """Get the inference statistics for the specified model name and version. Parameters ---------- model_name : str The name of the model to get statistics. The default value is an empty string, which means statistics of all models will be returned. model_version: str The version of the model to get inference statistics. The default value is an empty string which means then the server will return the statistics of all available model versions. headers: dict Optional dictionary specifying additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction Returns ------- dict The JSON dict holding the model inference statistics. Raises ------ InferenceServerException If unable to get the model inference statistics. """ if model_name != "": if type(model_version) != str: raise_error("model version must be a string") if model_version != "": request_uri = "v2/models/{}/versions/{}/stats".format( quote(model_name), model_version ) else: request_uri = "v2/models/{}/stats".format(quote(model_name)) else: request_uri = "v2/models/stats" response = self._get( request_uri=request_uri, headers=headers, query_params=query_params ) _raise_if_error(response) content = response.read() if self._verbose: print(content) return json.loads(content) def update_trace_settings( self, model_name=None, settings={}, headers=None, query_params=None ): """Update the trace settings for the specified model name, or global trace settings if model name is not given. Returns the trace settings after the update. Parameters ---------- model_name : str The name of the model to update trace settings. Specifying None or empty string will update the global trace settings. The default value is None. settings: dict The new trace setting values. Only the settings listed will be updated. If a trace setting is listed in the dictionary with a value of 'None', that setting will be cleared. headers: dict Optional dictionary specifying additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction Returns ------- dict The JSON dict holding the updated trace settings. Raises ------ InferenceServerException If unable to update the trace settings. """ if (model_name is not None) and (model_name != ""): request_uri = "v2/models/{}/trace/setting".format(quote(model_name)) else: request_uri = "v2/trace/setting" response = self._post( request_uri=request_uri, request_body=json.dumps(settings), headers=headers, query_params=query_params, ) _raise_if_error(response) content = response.read() if self._verbose: print(content) return json.loads(content) def get_trace_settings(self, model_name=None, headers=None, query_params=None): """Get the trace settings for the specified model name, or global trace settings if model name is not given Parameters ---------- model_name : str The name of the model to get trace settings. Specifying None or empty string will return the global trace settings. The default value is None. headers: dict Optional dictionary specifying additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction Returns ------- dict The JSON dict holding the trace settings. Raises ------ InferenceServerException If unable to get the trace settings. """ if (model_name is not None) and (model_name != ""): request_uri = "v2/models/{}/trace/setting".format(quote(model_name)) else: request_uri = "v2/trace/setting" response = self._get( request_uri=request_uri, headers=headers, query_params=query_params ) _raise_if_error(response) content = response.read() if self._verbose: print(content) return json.loads(content) def update_log_settings(self, settings, headers=None, query_params=None): """Update the global log settings of the Triton server. Parameters ---------- settings: dict The new log setting values. Only the settings listed will be updated. headers: dict Optional dictionary specifying additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction Returns ------- dict The JSON dict holding the updated log settings. Raises ------ InferenceServerException If unable to update the log settings. """ request_uri = "v2/logging" response = self._post( request_uri=request_uri, request_body=json.dumps(settings), headers=headers, query_params=query_params, ) _raise_if_error(response) content = response.read() if self._verbose: print(content) return json.loads(content) def get_log_settings(self, headers=None, query_params=None): """Get the global log settings for the Triton server Parameters ---------- headers: dict Optional dictionary specifying additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction Returns ------- dict The JSON dict holding the log settings. Raises ------ InferenceServerException If unable to get the log settings. """ request_uri = "v2/logging" response = self._get( request_uri=request_uri, headers=headers, query_params=query_params ) _raise_if_error(response) content = response.read() if self._verbose: print(content) return json.loads(content) def get_system_shared_memory_status( self, region_name="", headers=None, query_params=None ): """Request system shared memory status from the server. Parameters ---------- region_name : str The name of the region to query status. The default value is an empty string, which means that the status of all active system shared memory will be returned. headers: dict Optional dictionary specifying additional HTTP headers to include in the request query_params: dict Optional url query parameters to use in network transaction Returns ------- dict The JSON dict holding system shared memory status. Raises ------ InferenceServerException If unable to get the status of specified shared memory. """ if region_name != "": request_uri = "v2/systemsharedmemory/region/{}/status".format( quote(region_name) ) else: request_uri = "v2/systemsharedmemory/status" response = self._get( request_uri=request_uri, headers=headers, query_params=query_params ) _raise_if_error(response) content = response.read() if self._verbose: print(content) return json.loads(content) def register_system_shared_memory( self, name, key, byte_size, offset=0, headers=None, query_params=None ): """Request the server to register a system shared memory with the following specification. Parameters ---------- name : str The name of the region to register. key : str The key of the underlying memory object that contains the system shared memory region. byte_size : int The size of the system shared memory region, in bytes. offset : int Offset, in bytes, within the underlying memory object to the start of the system shared memory region. The default value is zero. headers: dict Optional dictionary specifying additional HTTP headers to include in the request query_params: dict Optional url query parameters to use in network transaction Raises ------ InferenceServerException If unable to register the specified system shared memory. """ request_uri = "v2/systemsharedmemory/region/{}/register".format(quote(name)) register_request = {"key": key, "offset": offset, "byte_size": byte_size} request_body = json.dumps(register_request) response = self._post( request_uri=request_uri, request_body=request_body, headers=headers, query_params=query_params, ) _raise_if_error(response) if self._verbose: print("Registered system shared memory with name '{}'".format(name)) def unregister_system_shared_memory(self, name="", headers=None, query_params=None): """Request the server to unregister a system shared memory with the specified name. Parameters ---------- name : str The name of the region to unregister. The default value is empty string which means all the system shared memory regions will be unregistered. headers: dict Optional dictionary specifying additional HTTP headers to include in the request query_params: dict Optional url query parameters to use in network transaction Raises ------ InferenceServerException If unable to unregister the specified system shared memory region. """ if name != "": request_uri = "v2/systemsharedmemory/region/{}/unregister".format( quote(name) ) else: request_uri = "v2/systemsharedmemory/unregister" response = self._post( request_uri=request_uri, request_body="", headers=headers, query_params=query_params, ) _raise_if_error(response) if self._verbose: if name != "": print("Unregistered system shared memory with name '{}'".format(name)) else: print("Unregistered all system shared memory regions") def get_cuda_shared_memory_status( self, region_name="", headers=None, query_params=None ): """Request cuda shared memory status from the server. Parameters ---------- region_name : str The name of the region to query status. The default value is an empty string, which means that the status of all active cuda shared memory will be returned. headers: dict Optional dictionary specifying additional HTTP headers to include in the request query_params: dict Optional url query parameters to use in network transaction Returns ------- dict The JSON dict holding cuda shared memory status. Raises ------ InferenceServerException If unable to get the status of specified shared memory. """ if region_name != "": request_uri = "v2/cudasharedmemory/region/{}/status".format( quote(region_name) ) else: request_uri = "v2/cudasharedmemory/status" response = self._get( request_uri=request_uri, headers=headers, query_params=query_params ) _raise_if_error(response) content = response.read() if self._verbose: print(content) return json.loads(content) def register_cuda_shared_memory( self, name, raw_handle, device_id, byte_size, headers=None, query_params=None ): """Request the server to register a system shared memory with the following specification. Parameters ---------- name : str The name of the region to register. raw_handle : bytes The raw serialized cudaIPC handle in base64 encoding. device_id : int The GPU device ID on which the cudaIPC handle was created. byte_size : int The size of the cuda shared memory region, in bytes. headers: dict Optional dictionary specifying additional HTTP headers to include in the request query_params: dict Optional url query parameters to use in network transaction Raises ------ InferenceServerException If unable to register the specified cuda shared memory. """ request_uri = "v2/cudasharedmemory/region/{}/register".format(quote(name)) register_request = { "raw_handle": {"b64": raw_handle}, "device_id": device_id, "byte_size": byte_size, } request_body = json.dumps(register_request) response = self._post( request_uri=request_uri, request_body=request_body, headers=headers, query_params=query_params, ) _raise_if_error(response) if self._verbose: print("Registered cuda shared memory with name '{}'".format(name)) def unregister_cuda_shared_memory(self, name="", headers=None, query_params=None): """Request the server to unregister a cuda shared memory with the specified name. Parameters ---------- name : str The name of the region to unregister. The default value is empty string which means all the cuda shared memory regions will be unregistered. headers: dict Optional dictionary specifying additional HTTP headers to include in the request query_params: dict Optional url query parameters to use in network transaction Raises ------ InferenceServerException If unable to unregister the specified cuda shared memory region. """ if name != "": request_uri = "v2/cudasharedmemory/region/{}/unregister".format(quote(name)) else: request_uri = "v2/cudasharedmemory/unregister" response = self._post( request_uri=request_uri, request_body="", headers=headers, query_params=query_params, ) _raise_if_error(response) if self._verbose: if name != "": print("Unregistered cuda shared memory with name '{}'".format(name)) else: print("Unregistered all cuda shared memory regions") @staticmethod def generate_request_body( inputs, outputs=None, request_id="", sequence_id=0, sequence_start=False, sequence_end=False, priority=0, timeout=None, parameters=None, ): """Generate a request body for inference using the supplied 'inputs' requesting the outputs specified by 'outputs'. Parameters ---------- inputs : list A list of InferInput objects, each describing data for a input tensor required by the model. outputs : list A list of InferRequestedOutput objects, each describing how the output data must be returned. If not specified all outputs produced by the model will be returned using default settings. request_id: str Optional identifier for the request. If specified will be returned in the response. Default value is an empty string which means no request_id will be used. sequence_id : int or str The unique identifier for the sequence being represented by the object. A value of 0 or "" means that the request does not belong to a sequence. Default is 0. sequence_start: bool Indicates whether the request being added marks the start of the sequence. Default value is False. This argument is ignored if 'sequence_id' is 0. sequence_end: bool Indicates whether the request being added marks the end of the sequence. Default value is False. This argument is ignored if 'sequence_id' is 0. priority : int Indicates the priority of the request. Priority value zero indicates that the default priority level should be used (i.e. same behavior as not specifying the priority parameter). Lower value priorities indicate higher priority levels. Thus the highest priority level is indicated by setting the parameter to 1, the next highest is 2, etc. If not provided, the server will handle the request using default setting for the model. timeout : int The timeout value for the request, in microseconds. If the request cannot be completed within the time the server can take a model-specific action such as terminating the request. If not provided, the server will handle the request using default setting for the model. This option is only respected by the model that is configured with dynamic batching. See here for more details: https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher parameters: dict Optional fields to be included in the 'parameters' fields. Returns ------- Bytes The request body of the inference. Int The byte size of the inference request header in the request body. Returns None if the whole request body constitutes the request header. Raises ------ InferenceServerException If server fails to perform inference. """ return _get_inference_request( inputs=inputs, request_id=request_id, outputs=outputs, sequence_id=sequence_id, sequence_start=sequence_start, sequence_end=sequence_end, priority=priority, timeout=timeout, custom_parameters=parameters, ) @staticmethod def parse_response_body( response_body, verbose=False, header_length=None, content_encoding=None ): """Generate a InferResult object from the given 'response_body' Parameters ---------- response_body : bytes The inference response from the server verbose : bool If True generate verbose output. Default value is False. header_length : int The length of the inference header if the header does not occupy the whole response body. Default value is None. content_encoding : string The encoding of the response body if it is compressed. Default value is None. Returns ------- InferResult The InferResult object generated from the response body """ return InferResult.from_response_body( response_body, verbose, header_length, content_encoding ) def infer( self, model_name, inputs, model_version="", outputs=None, request_id="", sequence_id=0, sequence_start=False, sequence_end=False, priority=0, timeout=None, headers=None, query_params=None, request_compression_algorithm=None, response_compression_algorithm=None, parameters=None, ): """Run synchronous inference using the supplied 'inputs' requesting the outputs specified by 'outputs'. Parameters ---------- model_name: str The name of the model to run inference. inputs : list A list of InferInput objects, each describing data for a input tensor required by the model. model_version: str The version of the model to run inference. The default value is an empty string which means then the server will choose a version based on the model and internal policy. outputs : list A list of InferRequestedOutput objects, each describing how the output data must be returned. If not specified all outputs produced by the model will be returned using default settings. request_id: str Optional identifier for the request. If specified will be returned in the response. Default value is an empty string which means no request_id will be used. sequence_id : int The unique identifier for the sequence being represented by the object. Default value is 0 which means that the request does not belong to a sequence. sequence_start: bool Indicates whether the request being added marks the start of the sequence. Default value is False. This argument is ignored if 'sequence_id' is 0. sequence_end: bool Indicates whether the request being added marks the end of the sequence. Default value is False. This argument is ignored if 'sequence_id' is 0. priority : int Indicates the priority of the request. Priority value zero indicates that the default priority level should be used (i.e. same behavior as not specifying the priority parameter). Lower value priorities indicate higher priority levels. Thus the highest priority level is indicated by setting the parameter to 1, the next highest is 2, etc. If not provided, the server will handle the request using default setting for the model. timeout : int The timeout value for the request, in microseconds. If the request cannot be completed within the time the server can take a model-specific action such as terminating the request. If not provided, the server will handle the request using default setting for the model. This option is only respected by the model that is configured with dynamic batching. See here for more details: https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher headers: dict Optional dictionary specifying additional HTTP headers to include in the request. query_params: dict Optional url query parameters to use in network transaction. request_compression_algorithm : str Optional HTTP compression algorithm to use for the request body on client side. Currently supports "deflate", "gzip" and None. By default, no compression is used. response_compression_algorithm : str Optional HTTP compression algorithm to request for the response body. Note that the response may not be compressed if the server does not support the specified algorithm. Currently supports "deflate", "gzip" and None. By default, no compression is requested. parameters: dict Optional fields to be included in the 'parameters' fields. Returns ------- InferResult The object holding the result of the inference. Raises ------ InferenceServerException If server fails to perform inference. """ request_body, json_size = _get_inference_request( inputs=inputs, request_id=request_id, outputs=outputs, sequence_id=sequence_id, sequence_start=sequence_start, sequence_end=sequence_end, priority=priority, timeout=timeout, custom_parameters=parameters, ) if request_compression_algorithm == "gzip": if headers is None: headers = {} headers["Content-Encoding"] = "gzip" request_body = gzip.compress(request_body) elif request_compression_algorithm == "deflate": if headers is None: headers = {} headers["Content-Encoding"] = "deflate" # "Content-Encoding: deflate" actually means compressing in zlib structure # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding request_body = zlib.compress(request_body) if response_compression_algorithm == "gzip": if headers is None: headers = {} headers["Accept-Encoding"] = "gzip" elif response_compression_algorithm == "deflate": if headers is None: headers = {} headers["Accept-Encoding"] = "deflate" if json_size is not None: if headers is None: headers = {} headers["Inference-Header-Content-Length"] = json_size if type(model_version) != str: raise_error("model version must be a string") if model_version != "": request_uri = "v2/models/{}/versions/{}/infer".format( quote(model_name), model_version ) else: request_uri = "v2/models/{}/infer".format(quote(model_name)) response = self._post( request_uri=request_uri, request_body=request_body, headers=headers, query_params=query_params, ) _raise_if_error(response) return InferResult(response, self._verbose) def async_infer( self, model_name, inputs, model_version="", outputs=None, request_id="", sequence_id=0, sequence_start=False, sequence_end=False, priority=0, timeout=None, headers=None, query_params=None, request_compression_algorithm=None, response_compression_algorithm=None, parameters=None, ): """Run asynchronous inference using the supplied 'inputs' requesting the outputs specified by 'outputs'. Even though this call is non-blocking, however, the actual number of concurrent requests to the server will be limited by the 'concurrency' parameter specified while creating this client. In other words, if the inflight async_infer exceeds the specified 'concurrency', the delivery of the exceeding request(s) to server will be blocked till the slot is made available by retrieving the results of previously issued requests. Parameters ---------- model_name: str The name of the model to run inference. inputs : list A list of InferInput objects, each describing data for a input tensor required by the model. model_version: str The version of the model to run inference. The default value is an empty string which means then the server will choose a version based on the model and internal policy. outputs : list A list of InferRequestedOutput objects, each describing how the output data must be returned. If not specified all outputs produced by the model will be returned using default settings. request_id: str Optional identifier for the request. If specified will be returned in the response. Default value is 'None' which means no request_id will be used. sequence_id : int The unique identifier for the sequence being represented by the object. Default value is 0 which means that the request does not belong to a sequence. sequence_start: bool Indicates whether the request being added marks the start of the sequence. Default value is False. This argument is ignored if 'sequence_id' is 0. sequence_end: bool Indicates whether the request being added marks the end of the sequence. Default value is False. This argument is ignored if 'sequence_id' is 0. priority : int Indicates the priority of the request. Priority value zero indicates that the default priority level should be used (i.e. same behavior as not specifying the priority parameter). Lower value priorities indicate higher priority levels. Thus the highest priority level is indicated by setting the parameter to 1, the next highest is 2, etc. If not provided, the server will handle the request using default setting for the model. timeout : int The timeout value for the request, in microseconds. If the request cannot be completed within the time the server can take a model-specific action such as terminating the request. If not provided, the server will handle the request using default setting for the model. This option is only respected by the model that is configured with dynamic batching. See here for more details: https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher headers: dict Optional dictionary specifying additional HTTP headers to include in the request query_params: dict Optional url query parameters to use in network transaction. request_compression_algorithm : str Optional HTTP compression algorithm to use for the request body on client side. Currently supports "deflate", "gzip" and None. By default, no compression is used. response_compression_algorithm : str Optional HTTP compression algorithm to request for the response body. Note that the response may not be compressed if the server does not support the specified algorithm. Currently supports "deflate", "gzip" and None. By default, no compression is requested. parameters : dict Optional custom parameters to be included in the inference request. Returns ------- InferAsyncRequest object The handle to the asynchronous inference request. Raises ------ InferenceServerException If server fails to issue inference. """ def wrapped_post(request_uri, request_body, headers, query_params): return self._post(request_uri, request_body, headers, query_params) request_body, json_size = _get_inference_request( inputs=inputs, request_id=request_id, outputs=outputs, sequence_id=sequence_id, sequence_start=sequence_start, sequence_end=sequence_end, priority=priority, timeout=timeout, custom_parameters=parameters, ) if request_compression_algorithm == "gzip": if headers is None: headers = {} headers["Content-Encoding"] = "gzip" request_body = gzip.compress(request_body) elif request_compression_algorithm == "deflate": if headers is None: headers = {} headers["Content-Encoding"] = "deflate" # "Content-Encoding: deflate" actually means compressing in zlib structure # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding request_body = zlib.compress(request_body) if response_compression_algorithm == "gzip": if headers is None: headers = {} headers["Accept-Encoding"] = "gzip" elif response_compression_algorithm == "deflate": if headers is None: headers = {} headers["Accept-Encoding"] = "deflate" if json_size is not None: if headers is None: headers = {} headers["Inference-Header-Content-Length"] = json_size if type(model_version) != str: raise_error("model version must be a string") if model_version != "": request_uri = "v2/models/{}/versions/{}/infer".format( quote(model_name), model_version ) else: request_uri = "v2/models/{}/infer".format(quote(model_name)) g = self._pool.apply_async( wrapped_post, (request_uri, request_body, headers, query_params) ) # Schedule the greenlet to run in this loop iteration g.start() # Relinquish control to greenlet loop. Using non-zero # value to ensure the control is transferred to the # event loop. gevent.sleep(0.01) if self._verbose: verbose_message = "Sent request" if request_id != "": verbose_message = verbose_message + " '{}'".format(request_id) print(verbose_message) return InferAsyncRequest(g, self._verbose)